Redo trie_by_cid package to be read-write

* use logrus instead of geth log
* remove benchmarks
* impl NodeIterator.ParentPath
* update go mods
This commit is contained in:
Roy Crihfield 2023-04-25 17:14:04 +08:00
parent c839739320
commit 7381b35dc6
43 changed files with 7237 additions and 707 deletions

5
go.mod
View File

@ -5,6 +5,7 @@ go 1.18
require ( require (
github.com/VictoriaMetrics/fastcache v1.6.0 github.com/VictoriaMetrics/fastcache v1.6.0
github.com/cerc-io/ipfs-ethdb/v5 v5.0.0-alpha github.com/cerc-io/ipfs-ethdb/v5 v5.0.0-alpha
github.com/davecgh/go-spew v1.1.1
github.com/ethereum/go-ethereum v1.11.5 github.com/ethereum/go-ethereum v1.11.5
github.com/hashicorp/golang-lru v0.5.5-0.20210104140557-80c98217689d github.com/hashicorp/golang-lru v0.5.5-0.20210104140557-80c98217689d
github.com/ipfs/go-cid v0.2.0 github.com/ipfs/go-cid v0.2.0
@ -21,13 +22,13 @@ require (
github.com/DataDog/zstd v1.5.2 // indirect github.com/DataDog/zstd v1.5.2 // indirect
github.com/beorn7/perks v1.0.1 // indirect github.com/beorn7/perks v1.0.1 // indirect
github.com/btcsuite/btcd/btcec/v2 v2.2.0 // indirect github.com/btcsuite/btcd/btcec/v2 v2.2.0 // indirect
github.com/btcsuite/btcd/chaincfg/chainhash v1.0.2 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/cockroachdb/errors v1.9.1 // indirect github.com/cockroachdb/errors v1.9.1 // indirect
github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b // indirect github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b // indirect
github.com/cockroachdb/pebble v0.0.0-20230209160836-829675f94811 // indirect github.com/cockroachdb/pebble v0.0.0-20230209160836-829675f94811 // indirect
github.com/cockroachdb/redact v1.1.3 // indirect github.com/cockroachdb/redact v1.1.3 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/deckarep/golang-set/v2 v2.1.0 // indirect github.com/deckarep/golang-set/v2 v2.1.0 // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 // indirect
github.com/edsrzf/mmap-go v1.0.0 // indirect github.com/edsrzf/mmap-go v1.0.0 // indirect
@ -112,4 +113,4 @@ require (
lukechampine.com/blake3 v1.1.6 // indirect lukechampine.com/blake3 v1.1.6 // indirect
) )
replace github.com/ethereum/go-ethereum v1.11.5 => github.com/cerc-io/go-ethereum v1.11.5-statediff-5.0.1-alpha replace github.com/ethereum/go-ethereum v1.11.5 => github.com/cerc-io/go-ethereum v1.11.5-statediff-5.0.4

7
go.sum
View File

@ -17,12 +17,13 @@ github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5
github.com/aymerick/raymond v2.0.3-0.20180322193309-b565731e1464+incompatible/go.mod h1:osfaiScAUVup+UC9Nfq76eWqDhXlp+4UYaA8uhTBO6g= github.com/aymerick/raymond v2.0.3-0.20180322193309-b565731e1464+incompatible/go.mod h1:osfaiScAUVup+UC9Nfq76eWqDhXlp+4UYaA8uhTBO6g=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/btcsuite/btcd v0.22.0-beta h1:LTDpDKUM5EeOFBPM8IXpinEcmZ6FWfNZbE3lfrfdnWo=
github.com/btcsuite/btcd/btcec/v2 v2.2.0 h1:fzn1qaOt32TuLjFlkzYSsBC35Q3KUjT1SwPxiMSCF5k= github.com/btcsuite/btcd/btcec/v2 v2.2.0 h1:fzn1qaOt32TuLjFlkzYSsBC35Q3KUjT1SwPxiMSCF5k=
github.com/btcsuite/btcd/btcec/v2 v2.2.0/go.mod h1:U7MHm051Al6XmscBQ0BoNydpOTsFAn707034b5nY8zU= github.com/btcsuite/btcd/btcec/v2 v2.2.0/go.mod h1:U7MHm051Al6XmscBQ0BoNydpOTsFAn707034b5nY8zU=
github.com/btcsuite/btcd/chaincfg/chainhash v1.0.2 h1:KdUfX2zKommPRa+PD0sWZUyXe9w277ABlgELO7H04IM=
github.com/btcsuite/btcd/chaincfg/chainhash v1.0.2/go.mod h1:7SFka0XMvUgj3hfZtydOrQY2mwhPclbT2snogU7SQQc=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/cerc-io/go-ethereum v1.11.5-statediff-5.0.1-alpha h1:x9muuG0Z2W/UAkwAq+0SXhYG9MCP6SJlGEfFNQa6JPI= github.com/cerc-io/go-ethereum v1.11.5-statediff-5.0.3-alpha h1:rhRmK/NeWMnQ07E4DuLb7WSh9FMotlXMPPaOrf8GJwM=
github.com/cerc-io/go-ethereum v1.11.5-statediff-5.0.1-alpha/go.mod h1:DIk2wFexjyzvyjuzSOtBEIAPRNZTnLXNbIHEyq1Igek= github.com/cerc-io/go-ethereum v1.11.5-statediff-5.0.3-alpha/go.mod h1:DIk2wFexjyzvyjuzSOtBEIAPRNZTnLXNbIHEyq1Igek=
github.com/cerc-io/ipfs-ethdb/v5 v5.0.0-alpha h1:I1iXTaIjbTH8ehzNXmT2waXcYBifi1yjK6FK3W3a0Pg= github.com/cerc-io/ipfs-ethdb/v5 v5.0.0-alpha h1:I1iXTaIjbTH8ehzNXmT2waXcYBifi1yjK6FK3W3a0Pg=
github.com/cerc-io/ipfs-ethdb/v5 v5.0.0-alpha/go.mod h1:EGAdV/YewEADFDDVF1k9GNwy8vNWR29Xb87sRHgMIng= github.com/cerc-io/ipfs-ethdb/v5 v5.0.0-alpha/go.mod h1:EGAdV/YewEADFDDVF1k9GNwy8vNWR29Xb87sRHgMIng=
github.com/cespare/cp v0.1.0 h1:SE+dxFebS7Iik5LK0tsi1k9ZCxEaFX4AjQmoyA+1dJk= github.com/cespare/cp v0.1.0 h1:SE+dxFebS7Iik5LK0tsi1k9ZCxEaFX4AjQmoyA+1dJk=

View File

@ -1,6 +1,10 @@
package internal package internal
import ( import (
"testing"
"time"
pgipfsethdb "github.com/cerc-io/ipfs-ethdb/v5/postgres/v0"
"github.com/ipfs/go-cid" "github.com/ipfs/go-cid"
"github.com/multiformats/go-multihash" "github.com/multiformats/go-multihash"
) )
@ -12,3 +16,12 @@ func Keccak256ToCid(codec uint64, h []byte) (cid.Cid, error) {
} }
return cid.NewCidV1(codec, buf), nil return cid.NewCidV1(codec, buf), nil
} }
// returns a cache config with unique name (groupcache names are global)
func MakeCacheConfig(t testing.TB) pgipfsethdb.CacheConfig {
return pgipfsethdb.CacheConfig{
Name: t.Name(),
Size: 3000000, // 3MB
ExpiryDuration: time.Hour,
}
}

3
trie_by_cid/doc.go Normal file
View File

@ -0,0 +1,3 @@
// This package is a near complete copy of go-ethereum/trie and go-ethereum/core/state, modified to use
// a v0 IPFS blockstore as the backing DB, i.e. DB values are indexed by CID rather than hash.
package trie_by_cid

View File

@ -20,7 +20,10 @@ var (
mockTD = big.NewInt(1) mockTD = big.NewInt(1)
) )
func IndexChain(dbConfig postgres.Config, stateCache state.Database, rootA, rootB common.Hash) error { // IndexStateDiff indexes a single statediff.
// - uses TestChainConfig
// - block hash/number are left as zero
func IndexStateDiff(dbConfig postgres.Config, stateCache state.Database, rootA, rootB common.Hash) error {
_, indexer, err := indexer.NewStateDiffIndexer( _, indexer, err := indexer.NewStateDiffIndexer(
context.Background(), ChainConfig, node.Info{}, dbConfig) context.Background(), ChainConfig, node.Info{}, dbConfig)
if err != nil { if err != nil {
@ -28,11 +31,10 @@ func IndexChain(dbConfig postgres.Config, stateCache state.Database, rootA, root
} }
defer indexer.Close() // fixme: hangs when using PGX driver defer indexer.Close() // fixme: hangs when using PGX driver
// generating statediff payload for block, and transform the data into Postgres
builder := statediff.NewBuilder(stateCache) builder := statediff.NewBuilder(stateCache)
block := types.NewBlock(&types.Header{Root: rootB}, nil, nil, nil, NewHasher()) block := types.NewBlock(&types.Header{Root: rootB}, nil, nil, nil, NewHasher())
// todo: use dummy block hashes to just produce trie structure for testing // uses zero block hash/number, we only need the trie structure here
args := statediff.Args{ args := statediff.Args{
OldStateRoot: rootA, OldStateRoot: rootA,
NewStateRoot: rootB, NewStateRoot: rootB,
@ -45,12 +47,7 @@ func IndexChain(dbConfig postgres.Config, stateCache state.Database, rootA, root
if err != nil { if err != nil {
return err return err
} }
// for _, node := range diff.Nodes { // we don't need to index diff.Nodes since we are just interested in the trie
// err := indexer.PushStateNode(tx, node, block.Hash().String())
// if err != nil {
// return err
// }
// }
for _, ipld := range diff.IPLDs { for _, ipld := range diff.IPLDs {
if err := indexer.PushIPLD(tx, ipld); err != nil { if err := indexer.PushIPLD(tx, ipld); err != nil {
return err return err

View File

@ -0,0 +1,136 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package state
import (
"github.com/ethereum/go-ethereum/common"
)
type accessList struct {
addresses map[common.Address]int
slots []map[common.Hash]struct{}
}
// ContainsAddress returns true if the address is in the access list.
func (al *accessList) ContainsAddress(address common.Address) bool {
_, ok := al.addresses[address]
return ok
}
// Contains checks if a slot within an account is present in the access list, returning
// separate flags for the presence of the account and the slot respectively.
func (al *accessList) Contains(address common.Address, slot common.Hash) (addressPresent bool, slotPresent bool) {
idx, ok := al.addresses[address]
if !ok {
// no such address (and hence zero slots)
return false, false
}
if idx == -1 {
// address yes, but no slots
return true, false
}
_, slotPresent = al.slots[idx][slot]
return true, slotPresent
}
// newAccessList creates a new accessList.
func newAccessList() *accessList {
return &accessList{
addresses: make(map[common.Address]int),
}
}
// Copy creates an independent copy of an accessList.
func (a *accessList) Copy() *accessList {
cp := newAccessList()
for k, v := range a.addresses {
cp.addresses[k] = v
}
cp.slots = make([]map[common.Hash]struct{}, len(a.slots))
for i, slotMap := range a.slots {
newSlotmap := make(map[common.Hash]struct{}, len(slotMap))
for k := range slotMap {
newSlotmap[k] = struct{}{}
}
cp.slots[i] = newSlotmap
}
return cp
}
// AddAddress adds an address to the access list, and returns 'true' if the operation
// caused a change (addr was not previously in the list).
func (al *accessList) AddAddress(address common.Address) bool {
if _, present := al.addresses[address]; present {
return false
}
al.addresses[address] = -1
return true
}
// AddSlot adds the specified (addr, slot) combo to the access list.
// Return values are:
// - address added
// - slot added
// For any 'true' value returned, a corresponding journal entry must be made.
func (al *accessList) AddSlot(address common.Address, slot common.Hash) (addrChange bool, slotChange bool) {
idx, addrPresent := al.addresses[address]
if !addrPresent || idx == -1 {
// Address not present, or addr present but no slots there
al.addresses[address] = len(al.slots)
slotmap := map[common.Hash]struct{}{slot: {}}
al.slots = append(al.slots, slotmap)
return !addrPresent, true
}
// There is already an (address,slot) mapping
slotmap := al.slots[idx]
if _, ok := slotmap[slot]; !ok {
slotmap[slot] = struct{}{}
// Journal add slot change
return false, true
}
// No changes required
return false, false
}
// DeleteSlot removes an (address, slot)-tuple from the access list.
// This operation needs to be performed in the same order as the addition happened.
// This method is meant to be used by the journal, which maintains ordering of
// operations.
func (al *accessList) DeleteSlot(address common.Address, slot common.Hash) {
idx, addrOk := al.addresses[address]
// There are two ways this can fail
if !addrOk {
panic("reverting slot change, address not present in list")
}
slotmap := al.slots[idx]
delete(slotmap, slot)
// If that was the last (first) slot, remove it
// Since additions and rollbacks are always performed in order,
// we can delete the item without worrying about screwing up later indices
if len(slotmap) == 0 {
al.slots = al.slots[:idx]
al.addresses[address] = -1
}
}
// DeleteAddress removes an address from the access list. This operation
// needs to be performed in the same order as the addition happened.
// This method is meant to be used by the journal, which maintains ordering of
// operations.
func (al *accessList) DeleteAddress(address common.Address) {
delete(al.addresses, address)
}

View File

@ -1,14 +1,30 @@
// Copyright 2017 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package state package state
import ( import (
"errors" "errors"
"fmt"
"github.com/VictoriaMetrics/fastcache"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/lru"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/statediff/indexer/ipld" "github.com/ethereum/go-ethereum/statediff/indexer/ipld"
lru "github.com/hashicorp/golang-lru"
"github.com/cerc-io/ipld-eth-statedb/internal" "github.com/cerc-io/ipld-eth-statedb/internal"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie" "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie"
@ -28,7 +44,10 @@ type Database interface {
OpenTrie(root common.Hash) (Trie, error) OpenTrie(root common.Hash) (Trie, error)
// OpenStorageTrie opens the storage trie of an account. // OpenStorageTrie opens the storage trie of an account.
OpenStorageTrie(addrHash, root common.Hash) (Trie, error) OpenStorageTrie(stateRoot common.Hash, addrHash, root common.Hash) (Trie, error)
// CopyTrie returns an independent copy of the given trie.
CopyTrie(Trie) Trie
// ContractCode retrieves a particular contract's code. // ContractCode retrieves a particular contract's code.
ContractCode(codeHash common.Hash) ([]byte, error) ContractCode(codeHash common.Hash) ([]byte, error)
@ -36,17 +55,79 @@ type Database interface {
// ContractCodeSize retrieves a particular contracts code's size. // ContractCodeSize retrieves a particular contracts code's size.
ContractCodeSize(codeHash common.Hash) (int, error) ContractCodeSize(codeHash common.Hash) (int, error)
// DiskDB returns the underlying key-value disk database.
DiskDB() ethdb.KeyValueStore
// TrieDB retrieves the low level trie database used for data storage. // TrieDB retrieves the low level trie database used for data storage.
TrieDB() *trie.Database TrieDB() *trie.Database
} }
// Trie is a Ethereum Merkle Patricia trie. // Trie is a Ethereum Merkle Patricia trie.
type Trie interface { type Trie interface {
// GetKey returns the sha3 preimage of a hashed key that was previously used
// to store a value.
//
// TODO(fjl): remove this when StateTrie is removed
GetKey([]byte) []byte
// TryGet returns the value for key stored in the trie. The value bytes must
// not be modified by the caller. If a node was not found in the database, a
// trie.MissingNodeError is returned.
TryGet(key []byte) ([]byte, error) TryGet(key []byte) ([]byte, error)
// TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not
// possible to use keybyte-encoding as the path might contain odd nibbles.
TryGetNode(path []byte) ([]byte, int, error) TryGetNode(path []byte) ([]byte, int, error)
TryGetAccount(key []byte) (*types.StateAccount, error)
// TryGetAccount abstracts an account read from the trie. It retrieves the
// account blob from the trie with provided account address and decodes it
// with associated decoding algorithm. If the specified account is not in
// the trie, nil will be returned. If the trie is corrupted(e.g. some nodes
// are missing or the account blob is incorrect for decoding), an error will
// be returned.
TryGetAccount(address common.Address) (*types.StateAccount, error)
// TryUpdate associates key with value in the trie. If value has length zero, any
// existing value is deleted from the trie. The value bytes must not be modified
// by the caller while they are stored in the trie. If a node was not found in the
// database, a trie.MissingNodeError is returned.
TryUpdate(key, value []byte) error
// TryUpdateAccount abstracts an account write to the trie. It encodes the
// provided account object with associated algorithm and then updates it
// in the trie with provided address.
TryUpdateAccount(address common.Address, account *types.StateAccount) error
// TryDelete removes any existing value for key from the trie. If a node was not
// found in the database, a trie.MissingNodeError is returned.
TryDelete(key []byte) error
// TryDeleteAccount abstracts an account deletion from the trie.
TryDeleteAccount(address common.Address) error
// Hash returns the root hash of the trie. It does not write to the database and
// can be used even if the trie doesn't have one.
Hash() common.Hash Hash() common.Hash
// Commit collects all dirty nodes in the trie and replace them with the
// corresponding node hash. All collected nodes(including dirty leaves if
// collectLeaf is true) will be encapsulated into a nodeset for return.
// The returned nodeset can be nil if the trie is clean(nothing to commit).
// Once the trie is committed, it's not usable anymore. A new trie must
// be created with new root and updated trie database for following usage
Commit(collectLeaf bool) (common.Hash, *trie.NodeSet)
// NodeIterator returns an iterator that returns nodes of the trie. Iteration
// starts at the key after the given start key.
NodeIterator(startKey []byte) trie.NodeIterator NodeIterator(startKey []byte) trie.NodeIterator
// Prove constructs a Merkle proof for key. The result contains all encoded nodes
// on the path to the value at key. The value itself is also included in the last
// node and can be retrieved by verifying the proof.
//
// If the trie does not contain a value for key, the returned proof contains all
// nodes of the longest existing prefix of the key (at least the root), ending
// with the node that proves the absence of the key.
Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) error Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) error
} }
@ -61,23 +142,34 @@ func NewDatabase(db ethdb.Database) Database {
// is safe for concurrent use and retains a lot of collapsed RLP trie nodes in a // is safe for concurrent use and retains a lot of collapsed RLP trie nodes in a
// large memory cache. // large memory cache.
func NewDatabaseWithConfig(db ethdb.Database, config *trie.Config) Database { func NewDatabaseWithConfig(db ethdb.Database, config *trie.Config) Database {
csc, _ := lru.New(codeSizeCacheSize)
return &cachingDB{ return &cachingDB{
db: trie.NewDatabaseWithConfig(db, config), disk: db,
codeSizeCache: csc, codeSizeCache: lru.NewCache[common.Hash, int](codeSizeCacheSize),
codeCache: fastcache.New(codeCacheSize), codeCache: lru.NewSizeConstrainedCache[common.Hash, []byte](codeCacheSize),
triedb: trie.NewDatabaseWithConfig(db, config),
}
}
// NewDatabaseWithNodeDB creates a state database with an already initialized node database.
func NewDatabaseWithNodeDB(db ethdb.Database, triedb *trie.Database) Database {
return &cachingDB{
disk: db,
codeSizeCache: lru.NewCache[common.Hash, int](codeSizeCacheSize),
codeCache: lru.NewSizeConstrainedCache[common.Hash, []byte](codeCacheSize),
triedb: triedb,
} }
} }
type cachingDB struct { type cachingDB struct {
db *trie.Database disk ethdb.KeyValueStore
codeSizeCache *lru.Cache codeSizeCache *lru.Cache[common.Hash, int]
codeCache *fastcache.Cache codeCache *lru.SizeConstrainedCache[common.Hash, []byte]
triedb *trie.Database
} }
// OpenTrie opens the main account trie at a specific root hash. // OpenTrie opens the main account trie at a specific root hash.
func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) { func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) {
tr, err := trie.NewStateTrie(common.Hash{}, root, db.db) tr, err := trie.NewStateTrie(trie.StateTrieID(root), db.triedb, trie.StateTrieCodec)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -85,29 +177,40 @@ func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) {
} }
// OpenStorageTrie opens the storage trie of an account. // OpenStorageTrie opens the storage trie of an account.
func (db *cachingDB) OpenStorageTrie(addrHash, root common.Hash) (Trie, error) { func (db *cachingDB) OpenStorageTrie(stateRoot common.Hash, addrHash, root common.Hash) (Trie, error) {
tr, err := trie.NewStorageTrie(addrHash, root, db.db) tr, err := trie.NewStateTrie(trie.StorageTrieID(stateRoot, addrHash, root), db.triedb, trie.StorageTrieCodec)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return tr, nil return tr, nil
} }
// CopyTrie returns an independent copy of the given trie.
func (db *cachingDB) CopyTrie(t Trie) Trie {
switch t := t.(type) {
case *trie.StateTrie:
return t.Copy()
default:
panic(fmt.Errorf("unknown trie type %T", t))
}
}
// ContractCode retrieves a particular contract's code. // ContractCode retrieves a particular contract's code.
func (db *cachingDB) ContractCode(codeHash common.Hash) ([]byte, error) { func (db *cachingDB) ContractCode(codeHash common.Hash) ([]byte, error) {
if code := db.codeCache.Get(nil, codeHash.Bytes()); len(code) > 0 { code, _ := db.codeCache.Get(codeHash)
if len(code) > 0 {
return code, nil return code, nil
} }
codeCID, err := internal.Keccak256ToCid(ipld.RawBinary, codeHash.Bytes()) cid, err := internal.Keccak256ToCid(ipld.RawBinary, codeHash.Bytes())
if err != nil { if err != nil {
return nil, err return nil, err
} }
code, err := db.db.DiskDB().Get(codeCID.Bytes()) code, err = db.disk.Get(cid.Bytes())
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(code) > 0 { if len(code) > 0 {
db.codeCache.Set(codeHash.Bytes(), code) db.codeCache.Add(codeHash, code)
db.codeSizeCache.Add(codeHash, len(code)) db.codeSizeCache.Add(codeHash, len(code))
return code, nil return code, nil
} }
@ -117,13 +220,18 @@ func (db *cachingDB) ContractCode(codeHash common.Hash) ([]byte, error) {
// ContractCodeSize retrieves a particular contracts code's size. // ContractCodeSize retrieves a particular contracts code's size.
func (db *cachingDB) ContractCodeSize(codeHash common.Hash) (int, error) { func (db *cachingDB) ContractCodeSize(codeHash common.Hash) (int, error) {
if cached, ok := db.codeSizeCache.Get(codeHash); ok { if cached, ok := db.codeSizeCache.Get(codeHash); ok {
return cached.(int), nil return cached, nil
} }
code, err := db.ContractCode(codeHash) code, err := db.ContractCode(codeHash)
return len(code), err return len(code), err
} }
// DiskDB returns the underlying key-value disk database.
func (db *cachingDB) DiskDB() ethdb.KeyValueStore {
return db.disk
}
// TrieDB retrieves any intermediate trie-node caching layer. // TrieDB retrieves any intermediate trie-node caching layer.
func (db *cachingDB) TrieDB() *trie.Database { func (db *cachingDB) TrieDB() *trie.Database {
return db.db return db.triedb
} }

View File

@ -0,0 +1,282 @@
// Copyright 2016 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package state
import (
"math/big"
"github.com/ethereum/go-ethereum/common"
)
// journalEntry is a modification entry in the state change journal that can be
// reverted on demand.
type journalEntry interface {
// revert undoes the changes introduced by this journal entry.
revert(*StateDB)
// dirtied returns the Ethereum address modified by this journal entry.
dirtied() *common.Address
}
// journal contains the list of state modifications applied since the last state
// commit. These are tracked to be able to be reverted in the case of an execution
// exception or request for reversal.
type journal struct {
entries []journalEntry // Current changes tracked by the journal
dirties map[common.Address]int // Dirty accounts and the number of changes
}
// newJournal creates a new initialized journal.
func newJournal() *journal {
return &journal{
dirties: make(map[common.Address]int),
}
}
// append inserts a new modification entry to the end of the change journal.
func (j *journal) append(entry journalEntry) {
j.entries = append(j.entries, entry)
if addr := entry.dirtied(); addr != nil {
j.dirties[*addr]++
}
}
// revert undoes a batch of journalled modifications along with any reverted
// dirty handling too.
func (j *journal) revert(statedb *StateDB, snapshot int) {
for i := len(j.entries) - 1; i >= snapshot; i-- {
// Undo the changes made by the operation
j.entries[i].revert(statedb)
// Drop any dirty tracking induced by the change
if addr := j.entries[i].dirtied(); addr != nil {
if j.dirties[*addr]--; j.dirties[*addr] == 0 {
delete(j.dirties, *addr)
}
}
}
j.entries = j.entries[:snapshot]
}
// dirty explicitly sets an address to dirty, even if the change entries would
// otherwise suggest it as clean. This method is an ugly hack to handle the RIPEMD
// precompile consensus exception.
func (j *journal) dirty(addr common.Address) {
j.dirties[addr]++
}
// length returns the current number of entries in the journal.
func (j *journal) length() int {
return len(j.entries)
}
type (
// Changes to the account trie.
createObjectChange struct {
account *common.Address
}
resetObjectChange struct {
prev *stateObject
prevdestruct bool
}
suicideChange struct {
account *common.Address
prev bool // whether account had already suicided
prevbalance *big.Int
}
// Changes to individual accounts.
balanceChange struct {
account *common.Address
prev *big.Int
}
nonceChange struct {
account *common.Address
prev uint64
}
storageChange struct {
account *common.Address
key, prevalue common.Hash
}
codeChange struct {
account *common.Address
prevcode, prevhash []byte
}
// Changes to other state values.
refundChange struct {
prev uint64
}
addLogChange struct {
txhash common.Hash
}
addPreimageChange struct {
hash common.Hash
}
touchChange struct {
account *common.Address
}
// Changes to the access list
accessListAddAccountChange struct {
address *common.Address
}
accessListAddSlotChange struct {
address *common.Address
slot *common.Hash
}
transientStorageChange struct {
account *common.Address
key, prevalue common.Hash
}
)
func (ch createObjectChange) revert(s *StateDB) {
delete(s.stateObjects, *ch.account)
delete(s.stateObjectsDirty, *ch.account)
}
func (ch createObjectChange) dirtied() *common.Address {
return ch.account
}
func (ch resetObjectChange) revert(s *StateDB) {
s.setStateObject(ch.prev)
if !ch.prevdestruct {
delete(s.stateObjectsDestruct, ch.prev.address)
}
}
func (ch resetObjectChange) dirtied() *common.Address {
return nil
}
func (ch suicideChange) revert(s *StateDB) {
obj := s.getStateObject(*ch.account)
if obj != nil {
obj.suicided = ch.prev
obj.setBalance(ch.prevbalance)
}
}
func (ch suicideChange) dirtied() *common.Address {
return ch.account
}
var ripemd = common.HexToAddress("0000000000000000000000000000000000000003")
func (ch touchChange) revert(s *StateDB) {
}
func (ch touchChange) dirtied() *common.Address {
return ch.account
}
func (ch balanceChange) revert(s *StateDB) {
s.getStateObject(*ch.account).setBalance(ch.prev)
}
func (ch balanceChange) dirtied() *common.Address {
return ch.account
}
func (ch nonceChange) revert(s *StateDB) {
s.getStateObject(*ch.account).setNonce(ch.prev)
}
func (ch nonceChange) dirtied() *common.Address {
return ch.account
}
func (ch codeChange) revert(s *StateDB) {
s.getStateObject(*ch.account).setCode(common.BytesToHash(ch.prevhash), ch.prevcode)
}
func (ch codeChange) dirtied() *common.Address {
return ch.account
}
func (ch storageChange) revert(s *StateDB) {
s.getStateObject(*ch.account).setState(ch.key, ch.prevalue)
}
func (ch storageChange) dirtied() *common.Address {
return ch.account
}
func (ch transientStorageChange) revert(s *StateDB) {
s.setTransientState(*ch.account, ch.key, ch.prevalue)
}
func (ch transientStorageChange) dirtied() *common.Address {
return nil
}
func (ch refundChange) revert(s *StateDB) {
s.refund = ch.prev
}
func (ch refundChange) dirtied() *common.Address {
return nil
}
func (ch addLogChange) revert(s *StateDB) {
logs := s.logs[ch.txhash]
if len(logs) == 1 {
delete(s.logs, ch.txhash)
} else {
s.logs[ch.txhash] = logs[:len(logs)-1]
}
s.logSize--
}
func (ch addLogChange) dirtied() *common.Address {
return nil
}
func (ch addPreimageChange) revert(s *StateDB) {
delete(s.preimages, ch.hash)
}
func (ch addPreimageChange) dirtied() *common.Address {
return nil
}
func (ch accessListAddAccountChange) revert(s *StateDB) {
/*
One important invariant here, is that whenever a (addr, slot) is added, if the
addr is not already present, the add causes two journal entries:
- one for the address,
- one for the (address,slot)
Therefore, when unrolling the change, we can always blindly delete the
(addr) at this point, since no storage adds can remain when come upon
a single (addr) change.
*/
s.accessList.DeleteAddress(*ch.address)
}
func (ch accessListAddAccountChange) dirtied() *common.Address {
return nil
}
func (ch accessListAddSlotChange) revert(s *StateDB) {
s.accessList.DeleteSlot(*ch.address, *ch.slot)
}
func (ch accessListAddSlotChange) dirtied() *common.Address {
return nil
}

View File

@ -0,0 +1,30 @@
// Copyright 2021 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package state
import "github.com/ethereum/go-ethereum/metrics"
var (
accountUpdatedMeter = metrics.NewRegisteredMeter("state/update/account", nil)
storageUpdatedMeter = metrics.NewRegisteredMeter("state/update/storage", nil)
accountDeletedMeter = metrics.NewRegisteredMeter("state/delete/account", nil)
storageDeletedMeter = metrics.NewRegisteredMeter("state/delete/storage", nil)
accountTrieUpdatedMeter = metrics.NewRegisteredMeter("state/update/accountnodes", nil)
storageTriesUpdatedMeter = metrics.NewRegisteredMeter("state/update/storagenodes", nil)
accountTrieDeletedMeter = metrics.NewRegisteredMeter("state/delete/accountnodes", nil)
storageTriesDeletedMeter = metrics.NewRegisteredMeter("state/delete/storagenodes", nil)
)

View File

@ -1,8 +1,25 @@
// Copyright 2014 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package state package state
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"io"
"math/big" "math/big"
"time" "time"
@ -11,15 +28,7 @@ import (
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/metrics"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
) // "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie"
var (
// emptyRoot is the known root hash of an empty trie.
// this is calculated as: emptyRoot = crypto.Keccak256(rlp.Encode([][]byte{}))
// that is, the keccak356 hash of the rlp encoding of an empty trie node (empty byte slice array)
emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421")
// emptyCodeHash is the CodeHash for an EOA, for an account without contract code deployed
emptyCodeHash = crypto.Keccak256(nil)
) )
type Code []byte type Code []byte
@ -34,7 +43,6 @@ func (s Storage) String() (str string) {
for key, value := range s { for key, value := range s {
str += fmt.Sprintf("%X : %X\n", key, value) str += fmt.Sprintf("%X : %X\n", key, value)
} }
return return
} }
@ -43,32 +51,39 @@ func (s Storage) Copy() Storage {
for key, value := range s { for key, value := range s {
cpy[key] = value cpy[key] = value
} }
return cpy return cpy
} }
// stateObject represents an Ethereum account which is being accessed. // stateObject represents an Ethereum account which is being modified.
// //
// The usage pattern is as follows: // The usage pattern is as follows:
// First you need to obtain a state object. // First you need to obtain a state object.
// Account values can be accessed through the object. // Account values can be accessed and modified through the object.
type stateObject struct { type stateObject struct {
address common.Address address common.Address
addrHash common.Hash // hash of ethereum address of the account addrHash common.Hash // hash of ethereum address of the account
data types.StateAccount data types.StateAccount
db *StateDB db *StateDB
// Caches. // Write caches.
trie Trie // storage trie, which becomes non-nil on first access trie Trie // storage trie, which becomes non-nil on first access
code Code // contract bytecode, which gets set when code is loaded code Code // contract bytecode, which gets set when code is loaded
originStorage Storage // Storage cache of original entries to dedup rewrites, reset for every transaction originStorage Storage // Storage cache of original entries to dedup rewrites, reset for every transaction
fakeStorage Storage // Fake storage which constructed by caller for debugging purpose. pendingStorage Storage // Storage entries that need to be flushed to disk, at the end of an entire block
dirtyStorage Storage // Storage entries that have been modified in the current transaction execution
// Cache flags.
// When an object is marked suicided it will be deleted from the trie
// during the "update" phase of the state transition.
dirtyCode bool // true if the code was updated
suicided bool
deleted bool
} }
// empty returns whether the account is considered empty. // empty returns whether the account is considered empty.
func (s *stateObject) empty() bool { func (s *stateObject) empty() bool {
return s.data.Nonce == 0 && s.data.Balance.Sign() == 0 && bytes.Equal(s.data.CodeHash, emptyCodeHash) return s.data.Nonce == 0 && s.data.Balance.Sign() == 0 && bytes.Equal(s.data.CodeHash, types.EmptyCodeHash.Bytes())
} }
// newObject creates a state object. // newObject creates a state object.
@ -77,10 +92,10 @@ func newObject(db *StateDB, address common.Address, data types.StateAccount) *st
data.Balance = new(big.Int) data.Balance = new(big.Int)
} }
if data.CodeHash == nil { if data.CodeHash == nil {
data.CodeHash = emptyCodeHash data.CodeHash = types.EmptyCodeHash.Bytes()
} }
if data.Root == (common.Hash{}) { if data.Root == (common.Hash{}) {
data.Root = emptyRoot data.Root = types.EmptyRootHash
} }
return &stateObject{ return &stateObject{
db: db, db: db,
@ -88,60 +103,117 @@ func newObject(db *StateDB, address common.Address, data types.StateAccount) *st
addrHash: crypto.Keccak256Hash(address[:]), addrHash: crypto.Keccak256Hash(address[:]),
data: data, data: data,
originStorage: make(Storage), originStorage: make(Storage),
pendingStorage: make(Storage),
dirtyStorage: make(Storage),
} }
} }
// setError remembers the first non-nil error it is called with. // EncodeRLP implements rlp.Encoder.
func (s *stateObject) setError(err error) { func (s *stateObject) EncodeRLP(w io.Writer) error {
s.db.setError(err) return rlp.Encode(w, &s.data)
} }
func (s *stateObject) getTrie(db Database) Trie { func (s *stateObject) markSuicided() {
s.suicided = true
}
func (s *stateObject) touch() {
s.db.journal.append(touchChange{
account: &s.address,
})
if s.address == ripemd {
// Explicitly put it in the dirty-cache, which is otherwise generated from
// flattened journals.
s.db.journal.dirty(s.address)
}
}
// getTrie returns the associated storage trie. The trie will be opened
// if it's not loaded previously. An error will be returned if trie can't
// be loaded.
func (s *stateObject) getTrie(db Database) (Trie, error) {
if s.trie == nil { if s.trie == nil {
// // Try fetching from prefetcher first // Try fetching from prefetcher first
// // We don't prefetch empty tries // We don't prefetch empty tries
// if s.data.Root != emptyRoot && s.db.prefetcher != nil { if s.data.Root != types.EmptyRootHash && s.db.prefetcher != nil {
// // When the miner is creating the pending state, there is no // When the miner is creating the pending state, there is no
// // prefetcher // prefetcher
// s.trie = s.db.prefetcher.trie(s.addrHash, s.data.Root) s.trie = s.db.prefetcher.trie(s.addrHash, s.data.Root)
// } }
if s.trie == nil { if s.trie == nil {
var err error tr, err := db.OpenStorageTrie(s.db.originalRoot, s.addrHash, s.data.Root)
s.trie, err = db.OpenStorageTrie(s.addrHash, s.data.Root)
if err != nil { if err != nil {
s.trie, _ = db.OpenStorageTrie(s.addrHash, common.Hash{}) return nil, err
s.setError(fmt.Errorf("can't create storage trie: %w", err)) }
s.trie = tr
} }
} }
return s.trie, nil
} }
return s.trie
// GetState retrieves a value from the account storage trie.
func (s *stateObject) GetState(db Database, key common.Hash) common.Hash {
// If we have a dirty value for this state entry, return it
value, dirty := s.dirtyStorage[key]
if dirty {
return value
}
// Otherwise return the entry's original value
return s.GetCommittedState(db, key)
} }
// GetCommittedState retrieves a value from the committed account storage trie. // GetCommittedState retrieves a value from the committed account storage trie.
func (s *stateObject) GetState(db Database, key common.Hash) common.Hash { func (s *stateObject) GetCommittedState(db Database, key common.Hash) common.Hash {
// If the fake storage is set, only lookup the state here(in the debugging mode) // If we have a pending write or clean cached, return that
if s.fakeStorage != nil { if value, pending := s.pendingStorage[key]; pending {
return s.fakeStorage[key] return value
} }
// If we have a cached value, return that
if value, cached := s.originStorage[key]; cached { if value, cached := s.originStorage[key]; cached {
return value return value
} }
// If no live objects are available, load from the database. // If the object was destructed in *this* block (and potentially resurrected),
// the storage has been cleared out, and we should *not* consult the previous
// database about any storage values. The only possible alternatives are:
// 1) resurrect happened, and new slot values were set -- those should
// have been handles via pendingStorage above.
// 2) we don't have new values, and can deliver empty response back
if _, destructed := s.db.stateObjectsDestruct[s.address]; destructed {
return common.Hash{}
}
// If no live objects are available, attempt to use snapshots
var (
enc []byte
err error
)
if s.db.snap != nil {
start := time.Now() start := time.Now()
enc, err := s.getTrie(db).TryGet(key.Bytes()) enc, err = s.db.snap.Storage(s.addrHash, crypto.Keccak256Hash(key.Bytes()))
if metrics.EnabledExpensive {
s.db.SnapshotStorageReads += time.Since(start)
}
}
// If the snapshot is unavailable or reading from it fails, load from the database.
if s.db.snap == nil || err != nil {
start := time.Now()
tr, err := s.getTrie(db)
if err != nil {
s.db.setError(err)
return common.Hash{}
}
enc, err = tr.TryGet(key.Bytes())
if metrics.EnabledExpensive { if metrics.EnabledExpensive {
s.db.StorageReads += time.Since(start) s.db.StorageReads += time.Since(start)
} }
if err != nil { if err != nil {
s.setError(err) s.db.setError(err)
return common.Hash{} return common.Hash{}
} }
}
var value common.Hash var value common.Hash
if len(enc) > 0 { if len(enc) > 0 {
_, content, _, err := rlp.Split(enc) _, content, _, err := rlp.Split(enc)
if err != nil { if err != nil {
s.setError(err) s.db.setError(err)
} }
value.SetBytes(content) value.SetBytes(content)
} }
@ -149,6 +221,182 @@ func (s *stateObject) GetState(db Database, key common.Hash) common.Hash {
return value return value
} }
// SetState updates a value in account storage.
func (s *stateObject) SetState(db Database, key, value common.Hash) {
// If the new value is the same as old, don't set
prev := s.GetState(db, key)
if prev == value {
return
}
// New value is different, update and journal the change
s.db.journal.append(storageChange{
account: &s.address,
key: key,
prevalue: prev,
})
s.setState(key, value)
}
func (s *stateObject) setState(key, value common.Hash) {
s.dirtyStorage[key] = value
}
// finalise moves all dirty storage slots into the pending area to be hashed or
// committed later. It is invoked at the end of every transaction.
func (s *stateObject) finalise(prefetch bool) {
slotsToPrefetch := make([][]byte, 0, len(s.dirtyStorage))
for key, value := range s.dirtyStorage {
s.pendingStorage[key] = value
if value != s.originStorage[key] {
slotsToPrefetch = append(slotsToPrefetch, common.CopyBytes(key[:])) // Copy needed for closure
}
}
if s.db.prefetcher != nil && prefetch && len(slotsToPrefetch) > 0 && s.data.Root != types.EmptyRootHash {
s.db.prefetcher.prefetch(s.addrHash, s.data.Root, slotsToPrefetch)
}
if len(s.dirtyStorage) > 0 {
s.dirtyStorage = make(Storage)
}
}
// updateTrie writes cached storage modifications into the object's storage trie.
// It will return nil if the trie has not been loaded and no changes have been
// made. An error will be returned if the trie can't be loaded/updated correctly.
func (s *stateObject) updateTrie(db Database) (Trie, error) {
// Make sure all dirty slots are finalized into the pending storage area
s.finalise(false) // Don't prefetch anymore, pull directly if need be
if len(s.pendingStorage) == 0 {
return s.trie, nil
}
// Track the amount of time wasted on updating the storage trie
if metrics.EnabledExpensive {
defer func(start time.Time) { s.db.StorageUpdates += time.Since(start) }(time.Now())
}
// The snapshot storage map for the object
var (
storage map[common.Hash][]byte
hasher = s.db.hasher
)
tr, err := s.getTrie(db)
if err != nil {
s.db.setError(err)
return nil, err
}
// Insert all the pending updates into the trie
usedStorage := make([][]byte, 0, len(s.pendingStorage))
for key, value := range s.pendingStorage {
// Skip noop changes, persist actual changes
if value == s.originStorage[key] {
continue
}
s.originStorage[key] = value
var v []byte
if (value == common.Hash{}) {
if err := tr.TryDelete(key[:]); err != nil {
s.db.setError(err)
return nil, err
}
s.db.StorageDeleted += 1
} else {
// Encoding []byte cannot fail, ok to ignore the error.
v, _ = rlp.EncodeToBytes(common.TrimLeftZeroes(value[:]))
if err := tr.TryUpdate(key[:], v); err != nil {
s.db.setError(err)
return nil, err
}
s.db.StorageUpdated += 1
}
// If state snapshotting is active, cache the data til commit
if s.db.snap != nil {
if storage == nil {
// Retrieve the old storage map, if available, create a new one otherwise
if storage = s.db.snapStorage[s.addrHash]; storage == nil {
storage = make(map[common.Hash][]byte)
s.db.snapStorage[s.addrHash] = storage
}
}
storage[crypto.HashData(hasher, key[:])] = v // v will be nil if it's deleted
}
usedStorage = append(usedStorage, common.CopyBytes(key[:])) // Copy needed for closure
}
if s.db.prefetcher != nil {
s.db.prefetcher.used(s.addrHash, s.data.Root, usedStorage)
}
if len(s.pendingStorage) > 0 {
s.pendingStorage = make(Storage)
}
return tr, nil
}
// UpdateRoot sets the trie root to the current root hash of. An error
// will be returned if trie root hash is not computed correctly.
func (s *stateObject) updateRoot(db Database) {
tr, err := s.updateTrie(db)
if err != nil {
return
}
// If nothing changed, don't bother with hashing anything
if tr == nil {
return
}
// Track the amount of time wasted on hashing the storage trie
if metrics.EnabledExpensive {
defer func(start time.Time) { s.db.StorageHashes += time.Since(start) }(time.Now())
}
s.data.Root = tr.Hash()
}
// AddBalance adds amount to s's balance.
// It is used to add funds to the destination account of a transfer.
func (s *stateObject) AddBalance(amount *big.Int) {
// EIP161: We must check emptiness for the objects such that the account
// clearing (0,0,0 objects) can take effect.
if amount.Sign() == 0 {
if s.empty() {
s.touch()
}
return
}
s.SetBalance(new(big.Int).Add(s.Balance(), amount))
}
// SubBalance removes amount from s's balance.
// It is used to remove funds from the origin account of a transfer.
func (s *stateObject) SubBalance(amount *big.Int) {
if amount.Sign() == 0 {
return
}
s.SetBalance(new(big.Int).Sub(s.Balance(), amount))
}
func (s *stateObject) SetBalance(amount *big.Int) {
s.db.journal.append(balanceChange{
account: &s.address,
prev: new(big.Int).Set(s.data.Balance),
})
s.setBalance(amount)
}
func (s *stateObject) setBalance(amount *big.Int) {
s.data.Balance = amount
}
func (s *stateObject) deepCopy(db *StateDB) *stateObject {
stateObject := newObject(db, s.address, s.data)
if s.trie != nil {
stateObject.trie = db.db.CopyTrie(s.trie)
}
stateObject.code = s.code
stateObject.dirtyStorage = s.dirtyStorage.Copy()
stateObject.originStorage = s.originStorage.Copy()
stateObject.pendingStorage = s.pendingStorage.Copy()
stateObject.suicided = s.suicided
stateObject.dirtyCode = s.dirtyCode
stateObject.deleted = s.deleted
return stateObject
}
// //
// Attribute accessors // Attribute accessors
// //
@ -163,12 +411,12 @@ func (s *stateObject) Code(db Database) []byte {
if s.code != nil { if s.code != nil {
return s.code return s.code
} }
if bytes.Equal(s.CodeHash(), emptyCodeHash) { if bytes.Equal(s.CodeHash(), types.EmptyCodeHash.Bytes()) {
return nil return nil
} }
code, err := db.ContractCode(common.BytesToHash(s.CodeHash())) code, err := db.ContractCode(common.BytesToHash(s.CodeHash()))
if err != nil { if err != nil {
s.setError(fmt.Errorf("can't load code hash %x: %v", s.CodeHash(), err)) s.db.setError(fmt.Errorf("can't load code hash %x: %v", s.CodeHash(), err))
} }
s.code = code s.code = code
return code return code
@ -181,16 +429,44 @@ func (s *stateObject) CodeSize(db Database) int {
if s.code != nil { if s.code != nil {
return len(s.code) return len(s.code)
} }
if bytes.Equal(s.CodeHash(), emptyCodeHash) { if bytes.Equal(s.CodeHash(), types.EmptyCodeHash.Bytes()) {
return 0 return 0
} }
size, err := db.ContractCodeSize(common.BytesToHash(s.CodeHash())) size, err := db.ContractCodeSize(common.BytesToHash(s.CodeHash()))
if err != nil { if err != nil {
s.setError(fmt.Errorf("can't load code size %x: %v", s.CodeHash(), err)) s.db.setError(fmt.Errorf("can't load code size %x: %v", s.CodeHash(), err))
} }
return size return size
} }
func (s *stateObject) SetCode(codeHash common.Hash, code []byte) {
prevcode := s.Code(s.db.db)
s.db.journal.append(codeChange{
account: &s.address,
prevhash: s.CodeHash(),
prevcode: prevcode,
})
s.setCode(codeHash, code)
}
func (s *stateObject) setCode(codeHash common.Hash, code []byte) {
s.code = code
s.data.CodeHash = codeHash[:]
s.dirtyCode = true
}
func (s *stateObject) SetNonce(nonce uint64) {
s.db.journal.append(nonceChange{
account: &s.address,
prev: s.data.Nonce,
})
s.setNonce(nonce)
}
func (s *stateObject) setNonce(nonce uint64) {
s.data.Nonce = nonce
}
func (s *stateObject) CodeHash() []byte { func (s *stateObject) CodeHash() []byte {
return s.data.CodeHash return s.data.CodeHash
} }
@ -202,10 +478,3 @@ func (s *stateObject) Balance() *big.Int {
func (s *stateObject) Nonce() uint64 { func (s *stateObject) Nonce() uint64 {
return s.data.Nonce return s.data.Nonce
} }
// Never called, but must be present to allow stateObject to be used
// as a vm.Account interface that also satisfies the vm.ContractRef
// interface. Interfaces are awesome.
func (s *stateObject) Value() *big.Int {
panic("Value on stateObject should never be called")
}

View File

@ -0,0 +1,46 @@
// Copyright 2019 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package state
import (
"bytes"
"testing"
"github.com/ethereum/go-ethereum/common"
)
func BenchmarkCutOriginal(b *testing.B) {
value := common.HexToHash("0x01")
for i := 0; i < b.N; i++ {
bytes.TrimLeft(value[:], "\x00")
}
}
func BenchmarkCutsetterFn(b *testing.B) {
value := common.HexToHash("0x01")
cutSetFn := func(r rune) bool { return r == 0 }
for i := 0; i < b.N; i++ {
bytes.TrimLeftFunc(value[:], cutSetFn)
}
}
func BenchmarkCutCustomTrim(b *testing.B) {
value := common.HexToHash("0x01")
for i := 0; i < b.N; i++ {
common.TrimLeftZeroes(value[:])
}
}

View File

@ -0,0 +1,219 @@
// Copyright 2014 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package state
import (
"bytes"
"context"
"math/big"
"testing"
pgipfsethdb "github.com/cerc-io/ipfs-ethdb/v5/postgres/v0"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/statediff/indexer/database/sql/postgres"
"github.com/cerc-io/ipld-eth-statedb/internal"
)
var (
testCtx = context.Background()
testConfig, _ = postgres.DefaultConfig.WithEnv()
teardownStatements = []string{`TRUNCATE ipld.blocks`}
)
type stateTest struct {
db ethdb.Database
state *StateDB
}
func newStateTest(t *testing.T) *stateTest {
pool, err := postgres.ConnectSQLX(testCtx, testConfig)
if err != nil {
t.Fatal(err)
}
db := pgipfsethdb.NewDatabase(pool, internal.MakeCacheConfig(t))
sdb, err := New(common.Hash{}, NewDatabase(db), nil)
if err != nil {
t.Fatal(err)
}
return &stateTest{db: db, state: sdb}
}
func TestNull(t *testing.T) {
s := newStateTest(t)
address := common.HexToAddress("0x823140710bf13990e4500136726d8b55")
s.state.CreateAccount(address)
//value := common.FromHex("0x823140710bf13990e4500136726d8b55")
var value common.Hash
s.state.SetState(address, common.Hash{}, value)
// s.state.Commit(false)
if value := s.state.GetState(address, common.Hash{}); value != (common.Hash{}) {
t.Errorf("expected empty current value, got %x", value)
}
if value := s.state.GetCommittedState(address, common.Hash{}); value != (common.Hash{}) {
t.Errorf("expected empty committed value, got %x", value)
}
}
func TestSnapshot(t *testing.T) {
stateobjaddr := common.BytesToAddress([]byte("aa"))
var storageaddr common.Hash
data1 := common.BytesToHash([]byte{42})
data2 := common.BytesToHash([]byte{43})
s := newStateTest(t)
// snapshot the genesis state
genesis := s.state.Snapshot()
// set initial state object value
s.state.SetState(stateobjaddr, storageaddr, data1)
snapshot := s.state.Snapshot()
// set a new state object value, revert it and ensure correct content
s.state.SetState(stateobjaddr, storageaddr, data2)
s.state.RevertToSnapshot(snapshot)
if v := s.state.GetState(stateobjaddr, storageaddr); v != data1 {
t.Errorf("wrong storage value %v, want %v", v, data1)
}
if v := s.state.GetCommittedState(stateobjaddr, storageaddr); v != (common.Hash{}) {
t.Errorf("wrong committed storage value %v, want %v", v, common.Hash{})
}
// revert up to the genesis state and ensure correct content
s.state.RevertToSnapshot(genesis)
if v := s.state.GetState(stateobjaddr, storageaddr); v != (common.Hash{}) {
t.Errorf("wrong storage value %v, want %v", v, common.Hash{})
}
if v := s.state.GetCommittedState(stateobjaddr, storageaddr); v != (common.Hash{}) {
t.Errorf("wrong committed storage value %v, want %v", v, common.Hash{})
}
}
func TestSnapshotEmpty(t *testing.T) {
s := newStateTest(t)
s.state.RevertToSnapshot(s.state.Snapshot())
}
func TestSnapshot2(t *testing.T) {
state, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil)
stateobjaddr0 := common.BytesToAddress([]byte("so0"))
stateobjaddr1 := common.BytesToAddress([]byte("so1"))
var storageaddr common.Hash
data0 := common.BytesToHash([]byte{17})
data1 := common.BytesToHash([]byte{18})
state.SetState(stateobjaddr0, storageaddr, data0)
state.SetState(stateobjaddr1, storageaddr, data1)
// db, trie are already non-empty values
so0 := state.getStateObject(stateobjaddr0)
so0.SetBalance(big.NewInt(42))
so0.SetNonce(43)
so0.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e'}), []byte{'c', 'a', 'f', 'e'})
so0.suicided = false
so0.deleted = false
state.setStateObject(so0)
// root, _ := state.Commit(false)
// state, _ = New(root, state.db, state.snaps)
// and one with deleted == true
so1 := state.getStateObject(stateobjaddr1)
so1.SetBalance(big.NewInt(52))
so1.SetNonce(53)
so1.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e', '2'}), []byte{'c', 'a', 'f', 'e', '2'})
so1.suicided = true
so1.deleted = true
state.setStateObject(so1)
so1 = state.getStateObject(stateobjaddr1)
if so1 != nil {
t.Fatalf("deleted object not nil when getting")
}
snapshot := state.Snapshot()
state.RevertToSnapshot(snapshot)
so0Restored := state.getStateObject(stateobjaddr0)
// Update lazily-loaded values before comparing.
so0Restored.GetState(state.db, storageaddr)
so0Restored.Code(state.db)
// non-deleted is equal (restored)
compareStateObjects(so0Restored, so0, t)
// deleted should be nil, both before and after restore of state copy
so1Restored := state.getStateObject(stateobjaddr1)
if so1Restored != nil {
t.Fatalf("deleted object not nil after restoring snapshot: %+v", so1Restored)
}
}
func compareStateObjects(so0, so1 *stateObject, t *testing.T) {
if so0.Address() != so1.Address() {
t.Fatalf("Address mismatch: have %v, want %v", so0.address, so1.address)
}
if so0.Balance().Cmp(so1.Balance()) != 0 {
t.Fatalf("Balance mismatch: have %v, want %v", so0.Balance(), so1.Balance())
}
if so0.Nonce() != so1.Nonce() {
t.Fatalf("Nonce mismatch: have %v, want %v", so0.Nonce(), so1.Nonce())
}
if so0.data.Root != so1.data.Root {
t.Errorf("Root mismatch: have %x, want %x", so0.data.Root[:], so1.data.Root[:])
}
if !bytes.Equal(so0.CodeHash(), so1.CodeHash()) {
t.Fatalf("CodeHash mismatch: have %v, want %v", so0.CodeHash(), so1.CodeHash())
}
if !bytes.Equal(so0.code, so1.code) {
t.Fatalf("Code mismatch: have %v, want %v", so0.code, so1.code)
}
if len(so1.dirtyStorage) != len(so0.dirtyStorage) {
t.Errorf("Dirty storage size mismatch: have %d, want %d", len(so1.dirtyStorage), len(so0.dirtyStorage))
}
for k, v := range so1.dirtyStorage {
if so0.dirtyStorage[k] != v {
t.Errorf("Dirty storage key %x mismatch: have %v, want %v", k, so0.dirtyStorage[k], v)
}
}
for k, v := range so0.dirtyStorage {
if so1.dirtyStorage[k] != v {
t.Errorf("Dirty storage key %x mismatch: have %v, want none.", k, v)
}
}
if len(so1.originStorage) != len(so0.originStorage) {
t.Errorf("Origin storage size mismatch: have %d, want %d", len(so1.originStorage), len(so0.originStorage))
}
for k, v := range so1.originStorage {
if so0.originStorage[k] != v {
t.Errorf("Origin storage key %x mismatch: have %v, want %v", k, so0.originStorage[k], v)
}
}
for k, v := range so0.originStorage {
if so1.originStorage[k] != v {
t.Errorf("Origin storage key %x mismatch: have %v, want none.", k, v)
}
}
}

View File

@ -1,17 +1,45 @@
// Copyright 2014 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
// Package state provides a caching layer atop the Ethereum state trie.
package state package state
import ( import (
"errors" "errors"
"fmt" "fmt"
"math/big" "math/big"
"sort"
"time" "time"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/state/snapshot"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/metrics"
"github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/rlp"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie"
) )
type revision struct {
id int
journalIndex int
}
type proofList [][]byte type proofList [][]byte
func (n *proofList) Put(key []byte, value []byte) error { func (n *proofList) Put(key []byte, value []byte) error {
@ -28,32 +56,76 @@ func (n *proofList) Delete(key []byte) error {
// nested states. It's the general query interface to retrieve: // nested states. It's the general query interface to retrieve:
// * Contracts // * Contracts
// * Accounts // * Accounts
//
// This implementation is read-only and performs no journaling, prefetching, or metrics tracking.
type StateDB struct { type StateDB struct {
db Database db Database
prefetcher *triePrefetcher
trie Trie trie Trie
hasher crypto.KeccakState hasher crypto.KeccakState
// originalRoot is the pre-state root, before any changes were made.
// It will be updated when the Commit is called.
originalRoot common.Hash
snaps *snapshot.Tree
snap snapshot.Snapshot
snapAccounts map[common.Hash][]byte
snapStorage map[common.Hash]map[common.Hash][]byte
// This map holds 'live' objects, which will get modified while processing a state transition. // This map holds 'live' objects, which will get modified while processing a state transition.
stateObjects map[common.Address]*stateObject stateObjects map[common.Address]*stateObject
stateObjectsPending map[common.Address]struct{} // State objects finalized but not yet written to the trie
stateObjectsDirty map[common.Address]struct{} // State objects modified in the current execution
stateObjectsDestruct map[common.Address]struct{} // State objects destructed in the block
// DB error. // DB error.
// State objects are used by the consensus core and VM which are // State objects are used by the consensus core and VM which are
// unable to deal with database-level errors. Any error that occurs // unable to deal with database-level errors. Any error that occurs
// during a database read is memoized here and will eventually be returned // during a database read is memoized here and will eventually be
// by StateDB.Commit. // returned by StateDB.Commit. Notably, this error is also shared
// by all cached state objects in case the database failure occurs
// when accessing state of accounts.
dbErr error dbErr error
// The refund counter, also used by state transitioning.
refund uint64
thash common.Hash
txIndex int
logs map[common.Hash][]*types.Log
logSize uint
preimages map[common.Hash][]byte preimages map[common.Hash][]byte
// Per-transaction access list
accessList *accessList
// Transient storage
transientStorage transientStorage
// Journal of state modifications. This is the backbone of
// Snapshot and RevertToSnapshot.
journal *journal
validRevisions []revision
nextRevisionId int
// Measurements gathered during execution for debugging purposes // Measurements gathered during execution for debugging purposes
AccountReads time.Duration AccountReads time.Duration
AccountHashes time.Duration
AccountUpdates time.Duration
StorageReads time.Duration StorageReads time.Duration
StorageHashes time.Duration
StorageUpdates time.Duration
SnapshotAccountReads time.Duration
SnapshotStorageReads time.Duration
AccountUpdated int
StorageUpdated int
AccountDeleted int
StorageDeleted int
} }
// New creates a new state from a given trie. // New creates a new state from a given trie.
func New(root common.Hash, db Database) (*StateDB, error) { func New(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, error) {
tr, err := db.OpenTrie(root) tr, err := db.OpenTrie(root)
if err != nil { if err != nil {
return nil, err return nil, err
@ -61,13 +133,50 @@ func New(root common.Hash, db Database) (*StateDB, error) {
sdb := &StateDB{ sdb := &StateDB{
db: db, db: db,
trie: tr, trie: tr,
originalRoot: root,
snaps: snaps,
stateObjects: make(map[common.Address]*stateObject), stateObjects: make(map[common.Address]*stateObject),
stateObjectsPending: make(map[common.Address]struct{}),
stateObjectsDirty: make(map[common.Address]struct{}),
stateObjectsDestruct: make(map[common.Address]struct{}),
logs: make(map[common.Hash][]*types.Log),
preimages: make(map[common.Hash][]byte), preimages: make(map[common.Hash][]byte),
journal: newJournal(),
accessList: newAccessList(),
transientStorage: newTransientStorage(),
hasher: crypto.NewKeccakState(), hasher: crypto.NewKeccakState(),
} }
if sdb.snaps != nil {
if sdb.snap = sdb.snaps.Snapshot(root); sdb.snap != nil {
sdb.snapAccounts = make(map[common.Hash][]byte)
sdb.snapStorage = make(map[common.Hash]map[common.Hash][]byte)
}
}
return sdb, nil return sdb, nil
} }
// StartPrefetcher initializes a new trie prefetcher to pull in nodes from the
// state trie concurrently while the state is mutated so that when we reach the
// commit phase, most of the needed data is already hot.
func (s *StateDB) StartPrefetcher(namespace string) {
if s.prefetcher != nil {
s.prefetcher.close()
s.prefetcher = nil
}
if s.snap != nil {
s.prefetcher = newTriePrefetcher(s.db, s.originalRoot, namespace)
}
}
// StopPrefetcher terminates a running prefetcher and reports any leftover stats
// from the gathered metrics.
func (s *StateDB) StopPrefetcher() {
if s.prefetcher != nil {
s.prefetcher.close()
s.prefetcher = nil
}
}
// setError remembers the first non-nil error it is called with. // setError remembers the first non-nil error it is called with.
func (s *StateDB) setError(err error) { func (s *StateDB) setError(err error) {
if s.dbErr == nil { if s.dbErr == nil {
@ -75,17 +184,44 @@ func (s *StateDB) setError(err error) {
} }
} }
// Error returns the memorized database failure occurred earlier.
func (s *StateDB) Error() error { func (s *StateDB) Error() error {
return s.dbErr return s.dbErr
} }
func (s *StateDB) AddLog(log *types.Log) { func (s *StateDB) AddLog(log *types.Log) {
panic("unsupported") s.journal.append(addLogChange{txhash: s.thash})
log.TxHash = s.thash
log.TxIndex = uint(s.txIndex)
log.Index = s.logSize
s.logs[s.thash] = append(s.logs[s.thash], log)
s.logSize++
}
// GetLogs returns the logs matching the specified transaction hash, and annotates
// them with the given blockNumber and blockHash.
func (s *StateDB) GetLogs(hash common.Hash, blockNumber uint64, blockHash common.Hash) []*types.Log {
logs := s.logs[hash]
for _, l := range logs {
l.BlockNumber = blockNumber
l.BlockHash = blockHash
}
return logs
}
func (s *StateDB) Logs() []*types.Log {
var logs []*types.Log
for _, lgs := range s.logs {
logs = append(logs, lgs...)
}
return logs
} }
// AddPreimage records a SHA3 preimage seen by the VM. // AddPreimage records a SHA3 preimage seen by the VM.
func (s *StateDB) AddPreimage(hash common.Hash, preimage []byte) { func (s *StateDB) AddPreimage(hash common.Hash, preimage []byte) {
if _, ok := s.preimages[hash]; !ok { if _, ok := s.preimages[hash]; !ok {
s.journal.append(addPreimageChange{hash: hash})
pi := make([]byte, len(preimage)) pi := make([]byte, len(preimage))
copy(pi, preimage) copy(pi, preimage)
s.preimages[hash] = pi s.preimages[hash] = pi
@ -99,13 +235,18 @@ func (s *StateDB) Preimages() map[common.Hash][]byte {
// AddRefund adds gas to the refund counter // AddRefund adds gas to the refund counter
func (s *StateDB) AddRefund(gas uint64) { func (s *StateDB) AddRefund(gas uint64) {
panic("unsupported") s.journal.append(refundChange{prev: s.refund})
s.refund += gas
} }
// SubRefund removes gas from the refund counter. // SubRefund removes gas from the refund counter.
// This method will panic if the refund counter goes below zero // This method will panic if the refund counter goes below zero
func (s *StateDB) SubRefund(gas uint64) { func (s *StateDB) SubRefund(gas uint64) {
panic("unsupported") s.journal.append(refundChange{prev: s.refund})
if gas > s.refund {
panic(fmt.Sprintf("Refund counter below zero (gas: %d > refund: %d)", gas, s.refund))
}
s.refund -= gas
} }
// Exist reports whether the given account address exists in the state. // Exist reports whether the given account address exists in the state.
@ -139,6 +280,11 @@ func (s *StateDB) GetNonce(addr common.Address) uint64 {
return 0 return 0
} }
// TxIndex returns the current transaction index set by Prepare.
func (s *StateDB) TxIndex() int {
return s.txIndex
}
func (s *StateDB) GetCode(addr common.Address) []byte { func (s *StateDB) GetCode(addr common.Address) []byte {
stateObject := s.getStateObject(addr) stateObject := s.getStateObject(addr)
if stateObject != nil { if stateObject != nil {
@ -186,18 +332,28 @@ func (s *StateDB) GetProofByHash(addrHash common.Hash) ([][]byte, error) {
// GetStorageProof returns the Merkle proof for given storage slot. // GetStorageProof returns the Merkle proof for given storage slot.
func (s *StateDB) GetStorageProof(a common.Address, key common.Hash) ([][]byte, error) { func (s *StateDB) GetStorageProof(a common.Address, key common.Hash) ([][]byte, error) {
var proof proofList trie, err := s.StorageTrie(a)
trie := s.StorageTrie(a) if err != nil {
if trie == nil { return nil, err
return proof, errors.New("storage trie for requested address does not exist")
} }
err := trie.Prove(crypto.Keccak256(key.Bytes()), 0, &proof) if trie == nil {
return proof, err return nil, errors.New("storage trie for requested address does not exist")
}
var proof proofList
err = trie.Prove(crypto.Keccak256(key.Bytes()), 0, &proof)
if err != nil {
return nil, err
}
return proof, nil
} }
// GetCommittedState retrieves a value from the given account's committed storage trie. // GetCommittedState retrieves a value from the given account's committed storage trie.
func (s *StateDB) GetCommittedState(addr common.Address, hash common.Hash) common.Hash { func (s *StateDB) GetCommittedState(addr common.Address, hash common.Hash) common.Hash {
return s.GetState(addr, hash) stateObject := s.getStateObject(addr)
if stateObject != nil {
return stateObject.GetCommittedState(s.db, hash)
}
return common.Hash{}
} }
// Database retrieves the low level database supporting the lower level trie ops. // Database retrieves the low level database supporting the lower level trie ops.
@ -205,17 +361,26 @@ func (s *StateDB) Database() Database {
return s.db return s.db
} }
// StorageTrie returns the storage trie of an account. // StorageTrie returns the storage trie of an account. The return value is a copy
// The return value is a copy and is nil for non-existent accounts. // and is nil for non-existent accounts. An error will be returned if storage trie
func (s *StateDB) StorageTrie(addr common.Address) Trie { // is existent but can't be loaded correctly.
func (s *StateDB) StorageTrie(addr common.Address) (Trie, error) {
stateObject := s.getStateObject(addr) stateObject := s.getStateObject(addr)
if stateObject == nil { if stateObject == nil {
return nil return nil, nil
} }
return stateObject.getTrie(s.db) cpy := stateObject.deepCopy(s)
if _, err := cpy.updateTrie(s.db); err != nil {
return nil, err
}
return cpy.getTrie(s.db)
} }
func (s *StateDB) HasSuicided(addr common.Address) bool { func (s *StateDB) HasSuicided(addr common.Address) bool {
stateObject := s.getStateObject(addr)
if stateObject != nil {
return stateObject.suicided
}
return false return false
} }
@ -225,34 +390,61 @@ func (s *StateDB) HasSuicided(addr common.Address) bool {
// AddBalance adds amount to the account associated with addr. // AddBalance adds amount to the account associated with addr.
func (s *StateDB) AddBalance(addr common.Address, amount *big.Int) { func (s *StateDB) AddBalance(addr common.Address, amount *big.Int) {
panic("unsupported") stateObject := s.GetOrNewStateObject(addr)
if stateObject != nil {
stateObject.AddBalance(amount)
}
} }
// SubBalance subtracts amount from the account associated with addr. // SubBalance subtracts amount from the account associated with addr.
func (s *StateDB) SubBalance(addr common.Address, amount *big.Int) { func (s *StateDB) SubBalance(addr common.Address, amount *big.Int) {
panic("unsupported") stateObject := s.GetOrNewStateObject(addr)
if stateObject != nil {
stateObject.SubBalance(amount)
}
} }
func (s *StateDB) SetBalance(addr common.Address, amount *big.Int) { func (s *StateDB) SetBalance(addr common.Address, amount *big.Int) {
panic("unsupported") stateObject := s.GetOrNewStateObject(addr)
if stateObject != nil {
stateObject.SetBalance(amount)
}
} }
func (s *StateDB) SetNonce(addr common.Address, nonce uint64) { func (s *StateDB) SetNonce(addr common.Address, nonce uint64) {
panic("unsupported") stateObject := s.GetOrNewStateObject(addr)
if stateObject != nil {
stateObject.SetNonce(nonce)
}
} }
func (s *StateDB) SetCode(addr common.Address, code []byte) { func (s *StateDB) SetCode(addr common.Address, code []byte) {
panic("unsupported") stateObject := s.GetOrNewStateObject(addr)
if stateObject != nil {
stateObject.SetCode(crypto.Keccak256Hash(code), code)
}
} }
func (s *StateDB) SetState(addr common.Address, key, value common.Hash) { func (s *StateDB) SetState(addr common.Address, key, value common.Hash) {
panic("unsupported") stateObject := s.GetOrNewStateObject(addr)
if stateObject != nil {
stateObject.SetState(s.db, key, value)
}
} }
// SetStorage replaces the entire storage for the specified account with given // SetStorage replaces the entire storage for the specified account with given
// storage. This function should only be used for debugging. // storage. This function should only be used for debugging.
func (s *StateDB) SetStorage(addr common.Address, storage map[common.Hash]common.Hash) { func (s *StateDB) SetStorage(addr common.Address, storage map[common.Hash]common.Hash) {
panic("unsupported") // SetStorage needs to wipe existing storage. We achieve this by pretending
// that the account self-destructed earlier in this block, by flagging
// it in stateObjectsDestruct. The effect of doing so is that storage lookups
// will not hit disk, since it is assumed that the disk-data is belonging
// to a previous incarnation of the object.
s.stateObjectsDestruct[addr] = struct{}{}
stateObject := s.GetOrNewStateObject(addr)
for k, v := range storage {
stateObject.SetState(s.db, k, v)
}
} }
// Suicide marks the given account as suicided. // Suicide marks the given account as suicided.
@ -261,36 +453,147 @@ func (s *StateDB) SetStorage(addr common.Address, storage map[common.Hash]common
// The account's state object is still available until the state is committed, // The account's state object is still available until the state is committed,
// getStateObject will return a non-nil account after Suicide. // getStateObject will return a non-nil account after Suicide.
func (s *StateDB) Suicide(addr common.Address) bool { func (s *StateDB) Suicide(addr common.Address) bool {
panic("unsupported") stateObject := s.getStateObject(addr)
if stateObject == nil {
return false return false
} }
s.journal.append(suicideChange{
account: &addr,
prev: stateObject.suicided,
prevbalance: new(big.Int).Set(stateObject.Balance()),
})
stateObject.markSuicided()
stateObject.data.Balance = new(big.Int)
return true
}
// SetTransientState sets transient storage for a given account. It
// adds the change to the journal so that it can be rolled back
// to its previous value if there is a revert.
func (s *StateDB) SetTransientState(addr common.Address, key, value common.Hash) {
prev := s.GetTransientState(addr, key)
if prev == value {
return
}
s.journal.append(transientStorageChange{
account: &addr,
key: key,
prevalue: prev,
})
s.setTransientState(addr, key, value)
}
// setTransientState is a lower level setter for transient storage. It
// is called during a revert to prevent modifications to the journal.
func (s *StateDB) setTransientState(addr common.Address, key, value common.Hash) {
s.transientStorage.Set(addr, key, value)
}
// GetTransientState gets transient storage for a given account.
func (s *StateDB) GetTransientState(addr common.Address, key common.Hash) common.Hash {
return s.transientStorage.Get(addr, key)
}
// //
// Setting, updating & deleting state object methods. // Setting, updating & deleting state object methods.
// //
// updateStateObject writes the given object to the trie.
func (s *StateDB) updateStateObject(obj *stateObject) {
// Track the amount of time wasted on updating the account from the trie
if metrics.EnabledExpensive {
defer func(start time.Time) { s.AccountUpdates += time.Since(start) }(time.Now())
}
// Encode the account and update the account trie
addr := obj.Address()
if err := s.trie.TryUpdateAccount(addr, &obj.data); err != nil {
s.setError(fmt.Errorf("updateStateObject (%x) error: %v", addr[:], err))
}
// If state snapshotting is active, cache the data til commit. Note, this
// update mechanism is not symmetric to the deletion, because whereas it is
// enough to track account updates at commit time, deletions need tracking
// at transaction boundary level to ensure we capture state clearing.
if s.snap != nil {
s.snapAccounts[obj.addrHash] = snapshot.SlimAccountRLP(obj.data.Nonce, obj.data.Balance, obj.data.Root, obj.data.CodeHash)
}
}
// deleteStateObject removes the given object from the state trie.
func (s *StateDB) deleteStateObject(obj *stateObject) {
// Track the amount of time wasted on deleting the account from the trie
if metrics.EnabledExpensive {
defer func(start time.Time) { s.AccountUpdates += time.Since(start) }(time.Now())
}
// Delete the account from the trie
addr := obj.Address()
if err := s.trie.TryDeleteAccount(addr); err != nil {
s.setError(fmt.Errorf("deleteStateObject (%x) error: %v", addr[:], err))
}
}
// getStateObject retrieves a state object given by the address, returning nil if // getStateObject retrieves a state object given by the address, returning nil if
// the object is not found or was deleted in this execution context. // the object is not found or was deleted in this execution context. If you need
// to differentiate between non-existent/just-deleted, use getDeletedStateObject.
func (s *StateDB) getStateObject(addr common.Address) *stateObject { func (s *StateDB) getStateObject(addr common.Address) *stateObject {
if obj := s.getDeletedStateObject(addr); obj != nil && !obj.deleted {
return obj
}
return nil
}
// getDeletedStateObject is similar to getStateObject, but instead of returning
// nil for a deleted state object, it returns the actual object with the deleted
// flag set. This is needed by the state journal to revert to the correct s-
// destructed object instead of wiping all knowledge about the state object.
func (s *StateDB) getDeletedStateObject(addr common.Address) *stateObject {
// Prefer live objects if any is available // Prefer live objects if any is available
if obj := s.stateObjects[addr]; obj != nil { if obj := s.stateObjects[addr]; obj != nil {
return obj return obj
} }
// If no live objects are available, load from the database // If no live objects are available, attempt to use snapshots
var data *types.StateAccount
if s.snap != nil {
start := time.Now()
acc, err := s.snap.Account(crypto.HashData(s.hasher, addr.Bytes()))
if metrics.EnabledExpensive {
s.SnapshotAccountReads += time.Since(start)
}
if err == nil {
if acc == nil {
return nil
}
data = &types.StateAccount{
Nonce: acc.Nonce,
Balance: acc.Balance,
CodeHash: acc.CodeHash,
Root: common.BytesToHash(acc.Root),
}
if len(data.CodeHash) == 0 {
data.CodeHash = types.EmptyCodeHash.Bytes()
}
if data.Root == (common.Hash{}) {
data.Root = types.EmptyRootHash
}
}
}
// If snapshot unavailable or reading from it failed, load from the database
if data == nil {
start := time.Now() start := time.Now()
var err error var err error
data, err := s.trie.TryGetAccount(addr.Bytes()) data, err = s.trie.TryGetAccount(addr)
if metrics.EnabledExpensive { if metrics.EnabledExpensive {
s.AccountReads += time.Since(start) s.AccountReads += time.Since(start)
} }
if err != nil { if err != nil {
s.setError(fmt.Errorf("getStateObject (%x) error: %w", addr.Bytes(), err)) s.setError(fmt.Errorf("getDeleteStateObject (%x) error: %w", addr.Bytes(), err))
return nil return nil
} }
if data == nil { if data == nil {
return nil return nil
} }
}
// Insert into the live set // Insert into the live set
obj := newObject(s, addr, *data) obj := newObject(s, addr, *data)
s.setStateObject(obj) s.setStateObject(obj)
@ -301,6 +604,40 @@ func (s *StateDB) setStateObject(object *stateObject) {
s.stateObjects[object.Address()] = object s.stateObjects[object.Address()] = object
} }
// GetOrNewStateObject retrieves a state object or create a new state object if nil.
func (s *StateDB) GetOrNewStateObject(addr common.Address) *stateObject {
stateObject := s.getStateObject(addr)
if stateObject == nil {
stateObject, _ = s.createObject(addr)
}
return stateObject
}
// createObject creates a new state object. If there is an existing account with
// the given address, it is overwritten and returned as the second return value.
func (s *StateDB) createObject(addr common.Address) (newobj, prev *stateObject) {
prev = s.getDeletedStateObject(addr) // Note, prev might have been deleted, we need that!
var prevdestruct bool
if prev != nil {
_, prevdestruct = s.stateObjectsDestruct[prev.address]
if !prevdestruct {
s.stateObjectsDestruct[prev.address] = struct{}{}
}
}
newobj = newObject(s, addr, types.StateAccount{})
if prev == nil {
s.journal.append(createObjectChange{account: &addr})
} else {
s.journal.append(resetObjectChange{prev: prev, prevdestruct: prevdestruct})
}
s.setStateObject(newobj)
if prev != nil && !prev.deleted {
return newobj, prev
}
return newobj, nil
}
// CreateAccount explicitly creates a state object. If a state object with the address // CreateAccount explicitly creates a state object. If a state object with the address
// already exists the balance is carried over to the new account. // already exists the balance is carried over to the new account.
// //
@ -312,58 +649,394 @@ func (s *StateDB) setStateObject(object *stateObject) {
// //
// Carrying over the balance ensures that Ether doesn't disappear. // Carrying over the balance ensures that Ether doesn't disappear.
func (s *StateDB) CreateAccount(addr common.Address) { func (s *StateDB) CreateAccount(addr common.Address) {
panic("unsupported") newObj, prev := s.createObject(addr)
if prev != nil {
newObj.setBalance(prev.data.Balance)
}
} }
func (db *StateDB) ForEachStorage(addr common.Address, cb func(key, value common.Hash) bool) error { func (db *StateDB) ForEachStorage(addr common.Address, cb func(key, value common.Hash) bool) error {
so := db.getStateObject(addr)
if so == nil {
return nil return nil
} }
tr, err := so.getTrie(db.db)
if err != nil {
return err
}
it := trie.NewIterator(tr.NodeIterator(nil))
for it.Next() {
key := common.BytesToHash(db.trie.GetKey(it.Key))
if value, dirty := so.dirtyStorage[key]; dirty {
if !cb(key, value) {
return nil
}
continue
}
if len(it.Value) > 0 {
_, content, _, err := rlp.Split(it.Value)
if err != nil {
return err
}
if !cb(key, common.BytesToHash(content)) {
return nil
}
}
}
return nil
}
// Copy creates a deep, independent copy of the state.
// Snapshots of the copied state cannot be applied to the copy.
func (s *StateDB) Copy() *StateDB {
// Copy all the basic fields, initialize the memory ones
state := &StateDB{
db: s.db,
trie: s.db.CopyTrie(s.trie),
originalRoot: s.originalRoot,
stateObjects: make(map[common.Address]*stateObject, len(s.journal.dirties)),
stateObjectsPending: make(map[common.Address]struct{}, len(s.stateObjectsPending)),
stateObjectsDirty: make(map[common.Address]struct{}, len(s.journal.dirties)),
stateObjectsDestruct: make(map[common.Address]struct{}, len(s.stateObjectsDestruct)),
refund: s.refund,
logs: make(map[common.Hash][]*types.Log, len(s.logs)),
logSize: s.logSize,
preimages: make(map[common.Hash][]byte, len(s.preimages)),
journal: newJournal(),
hasher: crypto.NewKeccakState(),
}
// Copy the dirty states, logs, and preimages
for addr := range s.journal.dirties {
// As documented [here](https://github.com/ethereum/go-ethereum/pull/16485#issuecomment-380438527),
// and in the Finalise-method, there is a case where an object is in the journal but not
// in the stateObjects: OOG after touch on ripeMD prior to Byzantium. Thus, we need to check for
// nil
if object, exist := s.stateObjects[addr]; exist {
// Even though the original object is dirty, we are not copying the journal,
// so we need to make sure that any side-effect the journal would have caused
// during a commit (or similar op) is already applied to the copy.
state.stateObjects[addr] = object.deepCopy(state)
state.stateObjectsDirty[addr] = struct{}{} // Mark the copy dirty to force internal (code/state) commits
state.stateObjectsPending[addr] = struct{}{} // Mark the copy pending to force external (account) commits
}
}
// Above, we don't copy the actual journal. This means that if the copy
// is copied, the loop above will be a no-op, since the copy's journal
// is empty. Thus, here we iterate over stateObjects, to enable copies
// of copies.
for addr := range s.stateObjectsPending {
if _, exist := state.stateObjects[addr]; !exist {
state.stateObjects[addr] = s.stateObjects[addr].deepCopy(state)
}
state.stateObjectsPending[addr] = struct{}{}
}
for addr := range s.stateObjectsDirty {
if _, exist := state.stateObjects[addr]; !exist {
state.stateObjects[addr] = s.stateObjects[addr].deepCopy(state)
}
state.stateObjectsDirty[addr] = struct{}{}
}
// Deep copy the destruction flag.
for addr := range s.stateObjectsDestruct {
state.stateObjectsDestruct[addr] = struct{}{}
}
for hash, logs := range s.logs {
cpy := make([]*types.Log, len(logs))
for i, l := range logs {
cpy[i] = new(types.Log)
*cpy[i] = *l
}
state.logs[hash] = cpy
}
for hash, preimage := range s.preimages {
state.preimages[hash] = preimage
}
// Do we need to copy the access list and transient storage?
// In practice: No. At the start of a transaction, these two lists are empty.
// In practice, we only ever copy state _between_ transactions/blocks, never
// in the middle of a transaction. However, it doesn't cost us much to copy
// empty lists, so we do it anyway to not blow up if we ever decide copy them
// in the middle of a transaction.
state.accessList = s.accessList.Copy()
state.transientStorage = s.transientStorage.Copy()
// If there's a prefetcher running, make an inactive copy of it that can
// only access data but does not actively preload (since the user will not
// know that they need to explicitly terminate an active copy).
if s.prefetcher != nil {
state.prefetcher = s.prefetcher.copy()
}
if s.snaps != nil {
// In order for the miner to be able to use and make additions
// to the snapshot tree, we need to copy that as well.
// Otherwise, any block mined by ourselves will cause gaps in the tree,
// and force the miner to operate trie-backed only
state.snaps = s.snaps
state.snap = s.snap
// deep copy needed
state.snapAccounts = make(map[common.Hash][]byte)
for k, v := range s.snapAccounts {
state.snapAccounts[k] = v
}
state.snapStorage = make(map[common.Hash]map[common.Hash][]byte)
for k, v := range s.snapStorage {
temp := make(map[common.Hash][]byte)
for kk, vv := range v {
temp[kk] = vv
}
state.snapStorage[k] = temp
}
}
return state
}
// Snapshot returns an identifier for the current revision of the state. // Snapshot returns an identifier for the current revision of the state.
func (s *StateDB) Snapshot() int { func (s *StateDB) Snapshot() int {
return 0 id := s.nextRevisionId
s.nextRevisionId++
s.validRevisions = append(s.validRevisions, revision{id, s.journal.length()})
return id
} }
// RevertToSnapshot reverts all state changes made since the given revision. // RevertToSnapshot reverts all state changes made since the given revision.
func (s *StateDB) RevertToSnapshot(revid int) { func (s *StateDB) RevertToSnapshot(revid int) {
panic("unsupported") // Find the snapshot in the stack of valid snapshots.
idx := sort.Search(len(s.validRevisions), func(i int) bool {
return s.validRevisions[i].id >= revid
})
if idx == len(s.validRevisions) || s.validRevisions[idx].id != revid {
panic(fmt.Errorf("revision id %v cannot be reverted", revid))
}
snapshot := s.validRevisions[idx].journalIndex
// Replay the journal to undo changes and remove invalidated snapshots
s.journal.revert(s, snapshot)
s.validRevisions = s.validRevisions[:idx]
} }
// GetRefund returns the current value of the refund counter. // GetRefund returns the current value of the refund counter.
func (s *StateDB) GetRefund() uint64 { func (s *StateDB) GetRefund() uint64 {
panic("unsupported") return s.refund
return 0
} }
// PrepareAccessList handles the preparatory steps for executing a state transition with // Finalise finalises the state by removing the destructed objects and clears
// regards to both EIP-2929 and EIP-2930: // the journal as well as the refunds. Finalise, however, will not push any updates
// into the tries just yet. Only IntermediateRoot or Commit will do that.
func (s *StateDB) Finalise(deleteEmptyObjects bool) {
addressesToPrefetch := make([][]byte, 0, len(s.journal.dirties))
for addr := range s.journal.dirties {
obj, exist := s.stateObjects[addr]
if !exist {
// ripeMD is 'touched' at block 1714175, in tx 0x1237f737031e40bcde4a8b7e717b2d15e3ecadfe49bb1bbc71ee9deb09c6fcf2
// That tx goes out of gas, and although the notion of 'touched' does not exist there, the
// touch-event will still be recorded in the journal. Since ripeMD is a special snowflake,
// it will persist in the journal even though the journal is reverted. In this special circumstance,
// it may exist in `s.journal.dirties` but not in `s.stateObjects`.
// Thus, we can safely ignore it here
continue
}
if obj.suicided || (deleteEmptyObjects && obj.empty()) {
obj.deleted = true
// We need to maintain account deletions explicitly (will remain
// set indefinitely).
s.stateObjectsDestruct[obj.address] = struct{}{}
// If state snapshotting is active, also mark the destruction there.
// Note, we can't do this only at the end of a block because multiple
// transactions within the same block might self destruct and then
// resurrect an account; but the snapshotter needs both events.
if s.snap != nil {
delete(s.snapAccounts, obj.addrHash) // Clear out any previously updated account data (may be recreated via a resurrect)
delete(s.snapStorage, obj.addrHash) // Clear out any previously updated storage data (may be recreated via a resurrect)
}
} else {
obj.finalise(true) // Prefetch slots in the background
}
s.stateObjectsPending[addr] = struct{}{}
s.stateObjectsDirty[addr] = struct{}{}
// At this point, also ship the address off to the precacher. The precacher
// will start loading tries, and when the change is eventually committed,
// the commit-phase will be a lot faster
addressesToPrefetch = append(addressesToPrefetch, common.CopyBytes(addr[:])) // Copy needed for closure
}
if s.prefetcher != nil && len(addressesToPrefetch) > 0 {
s.prefetcher.prefetch(common.Hash{}, s.originalRoot, addressesToPrefetch)
}
// Invalidate journal because reverting across transactions is not allowed.
s.clearJournalAndRefund()
}
// IntermediateRoot computes the current root hash of the state trie.
// It is called in between transactions to get the root hash that
// goes into transaction receipts.
func (s *StateDB) IntermediateRoot(deleteEmptyObjects bool) common.Hash {
// Finalise all the dirty storage states and write them into the tries
s.Finalise(deleteEmptyObjects)
// If there was a trie prefetcher operating, it gets aborted and irrevocably
// modified after we start retrieving tries. Remove it from the statedb after
// this round of use.
// //
// This is weird pre-byzantium since the first tx runs with a prefetcher and
// the remainder without, but pre-byzantium even the initial prefetcher is
// useless, so no sleep lost.
prefetcher := s.prefetcher
if s.prefetcher != nil {
defer func() {
s.prefetcher.close()
s.prefetcher = nil
}()
}
// Although naively it makes sense to retrieve the account trie and then do
// the contract storage and account updates sequentially, that short circuits
// the account prefetcher. Instead, let's process all the storage updates
// first, giving the account prefetches just a few more milliseconds of time
// to pull useful data from disk.
for addr := range s.stateObjectsPending {
if obj := s.stateObjects[addr]; !obj.deleted {
obj.updateRoot(s.db)
}
}
// Now we're about to start to write changes to the trie. The trie is so far
// _untouched_. We can check with the prefetcher, if it can give us a trie
// which has the same root, but also has some content loaded into it.
if prefetcher != nil {
if trie := prefetcher.trie(common.Hash{}, s.originalRoot); trie != nil {
s.trie = trie
}
}
usedAddrs := make([][]byte, 0, len(s.stateObjectsPending))
for addr := range s.stateObjectsPending {
if obj := s.stateObjects[addr]; obj.deleted {
s.deleteStateObject(obj)
s.AccountDeleted += 1
} else {
s.updateStateObject(obj)
s.AccountUpdated += 1
}
usedAddrs = append(usedAddrs, common.CopyBytes(addr[:])) // Copy needed for closure
}
if prefetcher != nil {
prefetcher.used(common.Hash{}, s.originalRoot, usedAddrs)
}
if len(s.stateObjectsPending) > 0 {
s.stateObjectsPending = make(map[common.Address]struct{})
}
// Track the amount of time wasted on hashing the account trie
if metrics.EnabledExpensive {
defer func(start time.Time) { s.AccountHashes += time.Since(start) }(time.Now())
}
return s.trie.Hash()
}
// SetTxContext sets the current transaction hash and index which are
// used when the EVM emits new state logs. It should be invoked before
// transaction execution.
func (s *StateDB) SetTxContext(thash common.Hash, ti int) {
s.thash = thash
s.txIndex = ti
}
func (s *StateDB) clearJournalAndRefund() {
if len(s.journal.entries) > 0 {
s.journal = newJournal()
s.refund = 0
}
s.validRevisions = s.validRevisions[:0] // Snapshots can be created without journal entries
}
// Prepare handles the preparatory steps for executing a state transition with.
// This method must be invoked before state transition.
//
// Berlin fork:
// - Add sender to access list (2929) // - Add sender to access list (2929)
// - Add destination to access list (2929) // - Add destination to access list (2929)
// - Add precompiles to access list (2929) // - Add precompiles to access list (2929)
// - Add the contents of the optional tx access list (2930) // - Add the contents of the optional tx access list (2930)
// //
// This method should only be called if Berlin/2929+2930 is applicable at the current number. // Potential EIPs:
func (s *StateDB) PrepareAccessList(sender common.Address, dst *common.Address, precompiles []common.Address, list types.AccessList) { // - Reset access list (Berlin)
panic("unsupported") // - Add coinbase to access list (EIP-3651)
// - Reset transient storage (EIP-1153)
func (s *StateDB) Prepare(rules params.Rules, sender, coinbase common.Address, dst *common.Address, precompiles []common.Address, list types.AccessList) {
if rules.IsBerlin {
// Clear out any leftover from previous executions
al := newAccessList()
s.accessList = al
al.AddAddress(sender)
if dst != nil {
al.AddAddress(*dst)
// If it's a create-tx, the destination will be added inside evm.create
}
for _, addr := range precompiles {
al.AddAddress(addr)
}
for _, el := range list {
al.AddAddress(el.Address)
for _, key := range el.StorageKeys {
al.AddSlot(el.Address, key)
}
}
if rules.IsShanghai { // EIP-3651: warm coinbase
al.AddAddress(coinbase)
}
}
// Reset transient storage at the beginning of transaction execution
s.transientStorage = newTransientStorage()
} }
// AddAddressToAccessList adds the given address to the access list // AddAddressToAccessList adds the given address to the access list
func (s *StateDB) AddAddressToAccessList(addr common.Address) { func (s *StateDB) AddAddressToAccessList(addr common.Address) {
panic("unsupported") if s.accessList.AddAddress(addr) {
s.journal.append(accessListAddAccountChange{&addr})
}
} }
// AddSlotToAccessList adds the given (address, slot)-tuple to the access list // AddSlotToAccessList adds the given (address, slot)-tuple to the access list
func (s *StateDB) AddSlotToAccessList(addr common.Address, slot common.Hash) { func (s *StateDB) AddSlotToAccessList(addr common.Address, slot common.Hash) {
panic("unsupported") addrMod, slotMod := s.accessList.AddSlot(addr, slot)
if addrMod {
// In practice, this should not happen, since there is no way to enter the
// scope of 'address' without having the 'address' become already added
// to the access list (via call-variant, create, etc).
// Better safe than sorry, though
s.journal.append(accessListAddAccountChange{&addr})
}
if slotMod {
s.journal.append(accessListAddSlotChange{
address: &addr,
slot: &slot,
})
}
} }
// AddressInAccessList returns true if the given address is in the access list. // AddressInAccessList returns true if the given address is in the access list.
func (s *StateDB) AddressInAccessList(addr common.Address) bool { func (s *StateDB) AddressInAccessList(addr common.Address) bool {
return false return s.accessList.ContainsAddress(addr)
} }
// SlotInAccessList returns true if the given (address, slot)-tuple is in the access list. // SlotInAccessList returns true if the given (address, slot)-tuple is in the access list.
func (s *StateDB) SlotInAccessList(addr common.Address, slot common.Hash) (addressPresent bool, slotPresent bool) { func (s *StateDB) SlotInAccessList(addr common.Address, slot common.Hash) (addressPresent bool, slotPresent bool) {
return return s.accessList.Contains(addr, slot)
}
// convertAccountSet converts a provided account set from address keyed to hash keyed.
func (s *StateDB) convertAccountSet(set map[common.Address]struct{}) map[common.Hash]struct{} {
ret := make(map[common.Hash]struct{})
for addr := range set {
obj, exist := s.stateObjects[addr]
if !exist {
ret[crypto.Keccak256Hash(addr[:])] = struct{}{}
} else {
ret[obj.addrHash] = struct{}{}
}
}
return ret
} }

View File

@ -0,0 +1,594 @@
// Copyright 2016 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package state
import (
"bytes"
"encoding/binary"
"fmt"
"math"
"math/big"
"math/rand"
"reflect"
"strings"
"sync"
"testing"
"testing/quick"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
)
// TestCopy tests that copying a StateDB object indeed makes the original and
// the copy independent of each other. This test is a regression test against
// https://github.com/ethereum/go-ethereum/pull/15549.
func TestCopy(t *testing.T) {
// Create a random state test to copy and modify "independently"
orig, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil)
for i := byte(0); i < 255; i++ {
obj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
obj.AddBalance(big.NewInt(int64(i)))
orig.updateStateObject(obj)
}
orig.Finalise(false)
// Copy the state
copy := orig.Copy()
// Copy the copy state
ccopy := copy.Copy()
// modify all in memory
for i := byte(0); i < 255; i++ {
origObj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
copyObj := copy.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
ccopyObj := ccopy.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
origObj.AddBalance(big.NewInt(2 * int64(i)))
copyObj.AddBalance(big.NewInt(3 * int64(i)))
ccopyObj.AddBalance(big.NewInt(4 * int64(i)))
orig.updateStateObject(origObj)
copy.updateStateObject(copyObj)
ccopy.updateStateObject(copyObj)
}
// Finalise the changes on all concurrently
finalise := func(wg *sync.WaitGroup, db *StateDB) {
defer wg.Done()
db.Finalise(true)
}
var wg sync.WaitGroup
wg.Add(3)
go finalise(&wg, orig)
go finalise(&wg, copy)
go finalise(&wg, ccopy)
wg.Wait()
// Verify that the three states have been updated independently
for i := byte(0); i < 255; i++ {
origObj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
copyObj := copy.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
ccopyObj := ccopy.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
if want := big.NewInt(3 * int64(i)); origObj.Balance().Cmp(want) != 0 {
t.Errorf("orig obj %d: balance mismatch: have %v, want %v", i, origObj.Balance(), want)
}
if want := big.NewInt(4 * int64(i)); copyObj.Balance().Cmp(want) != 0 {
t.Errorf("copy obj %d: balance mismatch: have %v, want %v", i, copyObj.Balance(), want)
}
if want := big.NewInt(5 * int64(i)); ccopyObj.Balance().Cmp(want) != 0 {
t.Errorf("copy obj %d: balance mismatch: have %v, want %v", i, ccopyObj.Balance(), want)
}
}
}
func TestSnapshotRandom(t *testing.T) {
config := &quick.Config{MaxCount: 1000}
err := quick.Check((*snapshotTest).run, config)
if cerr, ok := err.(*quick.CheckError); ok {
test := cerr.In[0].(*snapshotTest)
t.Errorf("%v:\n%s", test.err, test)
} else if err != nil {
t.Error(err)
}
}
// A snapshotTest checks that reverting StateDB snapshots properly undoes all changes
// captured by the snapshot. Instances of this test with pseudorandom content are created
// by Generate.
//
// The test works as follows:
//
// A new state is created and all actions are applied to it. Several snapshots are taken
// in between actions. The test then reverts each snapshot. For each snapshot the actions
// leading up to it are replayed on a fresh, empty state. The behaviour of all public
// accessor methods on the reverted state must match the return value of the equivalent
// methods on the replayed state.
type snapshotTest struct {
addrs []common.Address // all account addresses
actions []testAction // modifications to the state
snapshots []int // actions indexes at which snapshot is taken
err error // failure details are reported through this field
}
type testAction struct {
name string
fn func(testAction, *StateDB)
args []int64
noAddr bool
}
// newTestAction creates a random action that changes state.
func newTestAction(addr common.Address, r *rand.Rand) testAction {
actions := []testAction{
{
name: "SetBalance",
fn: func(a testAction, s *StateDB) {
s.SetBalance(addr, big.NewInt(a.args[0]))
},
args: make([]int64, 1),
},
{
name: "AddBalance",
fn: func(a testAction, s *StateDB) {
s.AddBalance(addr, big.NewInt(a.args[0]))
},
args: make([]int64, 1),
},
{
name: "SetNonce",
fn: func(a testAction, s *StateDB) {
s.SetNonce(addr, uint64(a.args[0]))
},
args: make([]int64, 1),
},
{
name: "SetState",
fn: func(a testAction, s *StateDB) {
var key, val common.Hash
binary.BigEndian.PutUint16(key[:], uint16(a.args[0]))
binary.BigEndian.PutUint16(val[:], uint16(a.args[1]))
s.SetState(addr, key, val)
},
args: make([]int64, 2),
},
{
name: "SetCode",
fn: func(a testAction, s *StateDB) {
code := make([]byte, 16)
binary.BigEndian.PutUint64(code, uint64(a.args[0]))
binary.BigEndian.PutUint64(code[8:], uint64(a.args[1]))
s.SetCode(addr, code)
},
args: make([]int64, 2),
},
{
name: "CreateAccount",
fn: func(a testAction, s *StateDB) {
s.CreateAccount(addr)
},
},
{
name: "Suicide",
fn: func(a testAction, s *StateDB) {
s.Suicide(addr)
},
},
{
name: "AddRefund",
fn: func(a testAction, s *StateDB) {
s.AddRefund(uint64(a.args[0]))
},
args: make([]int64, 1),
noAddr: true,
},
{
name: "AddLog",
fn: func(a testAction, s *StateDB) {
data := make([]byte, 2)
binary.BigEndian.PutUint16(data, uint16(a.args[0]))
s.AddLog(&types.Log{Address: addr, Data: data})
},
args: make([]int64, 1),
},
{
name: "AddPreimage",
fn: func(a testAction, s *StateDB) {
preimage := []byte{1}
hash := common.BytesToHash(preimage)
s.AddPreimage(hash, preimage)
},
args: make([]int64, 1),
},
{
name: "AddAddressToAccessList",
fn: func(a testAction, s *StateDB) {
s.AddAddressToAccessList(addr)
},
},
{
name: "AddSlotToAccessList",
fn: func(a testAction, s *StateDB) {
s.AddSlotToAccessList(addr,
common.Hash{byte(a.args[0])})
},
args: make([]int64, 1),
},
{
name: "SetTransientState",
fn: func(a testAction, s *StateDB) {
var key, val common.Hash
binary.BigEndian.PutUint16(key[:], uint16(a.args[0]))
binary.BigEndian.PutUint16(val[:], uint16(a.args[1]))
s.SetTransientState(addr, key, val)
},
args: make([]int64, 2),
},
}
action := actions[r.Intn(len(actions))]
var nameargs []string
if !action.noAddr {
nameargs = append(nameargs, addr.Hex())
}
for i := range action.args {
action.args[i] = rand.Int63n(100)
nameargs = append(nameargs, fmt.Sprint(action.args[i]))
}
action.name += strings.Join(nameargs, ", ")
return action
}
// Generate returns a new snapshot test of the given size. All randomness is
// derived from r.
func (*snapshotTest) Generate(r *rand.Rand, size int) reflect.Value {
// Generate random actions.
addrs := make([]common.Address, 50)
for i := range addrs {
addrs[i][0] = byte(i)
}
actions := make([]testAction, size)
for i := range actions {
addr := addrs[r.Intn(len(addrs))]
actions[i] = newTestAction(addr, r)
}
// Generate snapshot indexes.
nsnapshots := int(math.Sqrt(float64(size)))
if size > 0 && nsnapshots == 0 {
nsnapshots = 1
}
snapshots := make([]int, nsnapshots)
snaplen := len(actions) / nsnapshots
for i := range snapshots {
// Try to place the snapshots some number of actions apart from each other.
snapshots[i] = (i * snaplen) + r.Intn(snaplen)
}
return reflect.ValueOf(&snapshotTest{addrs, actions, snapshots, nil})
}
func (test *snapshotTest) String() string {
out := new(bytes.Buffer)
sindex := 0
for i, action := range test.actions {
if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
fmt.Fprintf(out, "---- snapshot %d ----\n", sindex)
sindex++
}
fmt.Fprintf(out, "%4d: %s\n", i, action.name)
}
return out.String()
}
func (test *snapshotTest) run() bool {
// Run all actions and create snapshots.
var (
state, _ = New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil)
snapshotRevs = make([]int, len(test.snapshots))
sindex = 0
)
for i, action := range test.actions {
if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
snapshotRevs[sindex] = state.Snapshot()
sindex++
}
action.fn(action, state)
}
// Revert all snapshots in reverse order. Each revert must yield a state
// that is equivalent to fresh state with all actions up the snapshot applied.
for sindex--; sindex >= 0; sindex-- {
checkstate, _ := New(common.Hash{}, state.Database(), nil)
for _, action := range test.actions[:test.snapshots[sindex]] {
action.fn(action, checkstate)
}
state.RevertToSnapshot(snapshotRevs[sindex])
if err := test.checkEqual(state, checkstate); err != nil {
test.err = fmt.Errorf("state mismatch after revert to snapshot %d\n%v", sindex, err)
return false
}
}
return true
}
// checkEqual checks that methods of state and checkstate return the same values.
func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
for _, addr := range test.addrs {
var err error
checkeq := func(op string, a, b interface{}) bool {
if err == nil && !reflect.DeepEqual(a, b) {
err = fmt.Errorf("got %s(%s) == %v, want %v", op, addr.Hex(), a, b)
return false
}
return true
}
// Check basic accessor methods.
checkeq("Exist", state.Exist(addr), checkstate.Exist(addr))
checkeq("HasSuicided", state.HasSuicided(addr), checkstate.HasSuicided(addr))
checkeq("GetBalance", state.GetBalance(addr), checkstate.GetBalance(addr))
checkeq("GetNonce", state.GetNonce(addr), checkstate.GetNonce(addr))
checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr))
checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr))
checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr))
// Check storage.
if obj := state.getStateObject(addr); obj != nil {
state.ForEachStorage(addr, func(key, value common.Hash) bool {
return checkeq("GetState("+key.Hex()+")", checkstate.GetState(addr, key), value)
})
checkstate.ForEachStorage(addr, func(key, value common.Hash) bool {
return checkeq("GetState("+key.Hex()+")", checkstate.GetState(addr, key), value)
})
}
if err != nil {
return err
}
}
if state.GetRefund() != checkstate.GetRefund() {
return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d",
state.GetRefund(), checkstate.GetRefund())
}
if !reflect.DeepEqual(state.GetLogs(common.Hash{}, 0, common.Hash{}), checkstate.GetLogs(common.Hash{}, 0, common.Hash{})) {
return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v",
state.GetLogs(common.Hash{}, 0, common.Hash{}), checkstate.GetLogs(common.Hash{}, 0, common.Hash{}))
}
return nil
}
// TestCopyOfCopy tests that modified objects are carried over to the copy, and the copy of the copy.
// See https://github.com/ethereum/go-ethereum/pull/15225#issuecomment-380191512
func TestCopyOfCopy(t *testing.T) {
state, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil)
addr := common.HexToAddress("aaaa")
state.SetBalance(addr, big.NewInt(42))
if got := state.Copy().GetBalance(addr).Uint64(); got != 42 {
t.Fatalf("1st copy fail, expected 42, got %v", got)
}
if got := state.Copy().Copy().GetBalance(addr).Uint64(); got != 42 {
t.Fatalf("2nd copy fail, expected 42, got %v", got)
}
}
func TestStateDBAccessList(t *testing.T) {
// Some helpers
addr := func(a string) common.Address {
return common.HexToAddress(a)
}
slot := func(a string) common.Hash {
return common.HexToHash(a)
}
memDb := rawdb.NewMemoryDatabase()
db := NewDatabase(memDb)
state, _ := New(common.Hash{}, db, nil)
state.accessList = newAccessList()
verifyAddrs := func(astrings ...string) {
t.Helper()
// convert to common.Address form
var addresses []common.Address
var addressMap = make(map[common.Address]struct{})
for _, astring := range astrings {
address := addr(astring)
addresses = append(addresses, address)
addressMap[address] = struct{}{}
}
// Check that the given addresses are in the access list
for _, address := range addresses {
if !state.AddressInAccessList(address) {
t.Fatalf("expected %x to be in access list", address)
}
}
// Check that only the expected addresses are present in the access list
for address := range state.accessList.addresses {
if _, exist := addressMap[address]; !exist {
t.Fatalf("extra address %x in access list", address)
}
}
}
verifySlots := func(addrString string, slotStrings ...string) {
if !state.AddressInAccessList(addr(addrString)) {
t.Fatalf("scope missing address/slots %v", addrString)
}
var address = addr(addrString)
// convert to common.Hash form
var slots []common.Hash
var slotMap = make(map[common.Hash]struct{})
for _, slotString := range slotStrings {
s := slot(slotString)
slots = append(slots, s)
slotMap[s] = struct{}{}
}
// Check that the expected items are in the access list
for i, s := range slots {
if _, slotPresent := state.SlotInAccessList(address, s); !slotPresent {
t.Fatalf("input %d: scope missing slot %v (address %v)", i, s, addrString)
}
}
// Check that no extra elements are in the access list
index := state.accessList.addresses[address]
if index >= 0 {
stateSlots := state.accessList.slots[index]
for s := range stateSlots {
if _, slotPresent := slotMap[s]; !slotPresent {
t.Fatalf("scope has extra slot %v (address %v)", s, addrString)
}
}
}
}
state.AddAddressToAccessList(addr("aa")) // 1
state.AddSlotToAccessList(addr("bb"), slot("01")) // 2,3
state.AddSlotToAccessList(addr("bb"), slot("02")) // 4
verifyAddrs("aa", "bb")
verifySlots("bb", "01", "02")
// Make a copy
stateCopy1 := state.Copy()
if exp, got := 4, state.journal.length(); exp != got {
t.Fatalf("journal length mismatch: have %d, want %d", got, exp)
}
// same again, should cause no journal entries
state.AddSlotToAccessList(addr("bb"), slot("01"))
state.AddSlotToAccessList(addr("bb"), slot("02"))
state.AddAddressToAccessList(addr("aa"))
if exp, got := 4, state.journal.length(); exp != got {
t.Fatalf("journal length mismatch: have %d, want %d", got, exp)
}
// some new ones
state.AddSlotToAccessList(addr("bb"), slot("03")) // 5
state.AddSlotToAccessList(addr("aa"), slot("01")) // 6
state.AddSlotToAccessList(addr("cc"), slot("01")) // 7,8
state.AddAddressToAccessList(addr("cc"))
if exp, got := 8, state.journal.length(); exp != got {
t.Fatalf("journal length mismatch: have %d, want %d", got, exp)
}
verifyAddrs("aa", "bb", "cc")
verifySlots("aa", "01")
verifySlots("bb", "01", "02", "03")
verifySlots("cc", "01")
// now start rolling back changes
state.journal.revert(state, 7)
if _, ok := state.SlotInAccessList(addr("cc"), slot("01")); ok {
t.Fatalf("slot present, expected missing")
}
verifyAddrs("aa", "bb", "cc")
verifySlots("aa", "01")
verifySlots("bb", "01", "02", "03")
state.journal.revert(state, 6)
if state.AddressInAccessList(addr("cc")) {
t.Fatalf("addr present, expected missing")
}
verifyAddrs("aa", "bb")
verifySlots("aa", "01")
verifySlots("bb", "01", "02", "03")
state.journal.revert(state, 5)
if _, ok := state.SlotInAccessList(addr("aa"), slot("01")); ok {
t.Fatalf("slot present, expected missing")
}
verifyAddrs("aa", "bb")
verifySlots("bb", "01", "02", "03")
state.journal.revert(state, 4)
if _, ok := state.SlotInAccessList(addr("bb"), slot("03")); ok {
t.Fatalf("slot present, expected missing")
}
verifyAddrs("aa", "bb")
verifySlots("bb", "01", "02")
state.journal.revert(state, 3)
if _, ok := state.SlotInAccessList(addr("bb"), slot("02")); ok {
t.Fatalf("slot present, expected missing")
}
verifyAddrs("aa", "bb")
verifySlots("bb", "01")
state.journal.revert(state, 2)
if _, ok := state.SlotInAccessList(addr("bb"), slot("01")); ok {
t.Fatalf("slot present, expected missing")
}
verifyAddrs("aa", "bb")
state.journal.revert(state, 1)
if state.AddressInAccessList(addr("bb")) {
t.Fatalf("addr present, expected missing")
}
verifyAddrs("aa")
state.journal.revert(state, 0)
if state.AddressInAccessList(addr("aa")) {
t.Fatalf("addr present, expected missing")
}
if got, exp := len(state.accessList.addresses), 0; got != exp {
t.Fatalf("expected empty, got %d", got)
}
if got, exp := len(state.accessList.slots), 0; got != exp {
t.Fatalf("expected empty, got %d", got)
}
// Check the copy
// Make a copy
state = stateCopy1
verifyAddrs("aa", "bb")
verifySlots("bb", "01", "02")
if got, exp := len(state.accessList.addresses), 2; got != exp {
t.Fatalf("expected empty, got %d", got)
}
if got, exp := len(state.accessList.slots), 1; got != exp {
t.Fatalf("expected empty, got %d", got)
}
}
func TestStateDBTransientStorage(t *testing.T) {
memDb := rawdb.NewMemoryDatabase()
db := NewDatabase(memDb)
state, _ := New(common.Hash{}, db, nil)
key := common.Hash{0x01}
value := common.Hash{0x02}
addr := common.Address{}
state.SetTransientState(addr, key, value)
if exp, got := 1, state.journal.length(); exp != got {
t.Fatalf("journal length mismatch: have %d, want %d", got, exp)
}
// the retrieved value should equal what was set
if got := state.GetTransientState(addr, key); got != value {
t.Fatalf("transient storage mismatch: have %x, want %x", got, value)
}
// revert the transient state being set and then check that the
// value is now the empty hash
state.journal.revert(state, 0)
if got, exp := state.GetTransientState(addr, key), (common.Hash{}); exp != got {
t.Fatalf("transient storage mismatch: have %x, want %x", got, exp)
}
// set transient state and then copy the statedb and ensure that
// the transient state is copied
state.SetTransientState(addr, key, value)
cpy := state.Copy()
if got := cpy.GetTransientState(addr, key); got != value {
t.Fatalf("transient storage mismatch: have %x, want %x", got, value)
}
}

View File

@ -0,0 +1,55 @@
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package state
import (
"github.com/ethereum/go-ethereum/common"
)
// transientStorage is a representation of EIP-1153 "Transient Storage".
type transientStorage map[common.Address]Storage
// newTransientStorage creates a new instance of a transientStorage.
func newTransientStorage() transientStorage {
return make(transientStorage)
}
// Set sets the transient-storage `value` for `key` at the given `addr`.
func (t transientStorage) Set(addr common.Address, key, value common.Hash) {
if _, ok := t[addr]; !ok {
t[addr] = make(Storage)
}
t[addr][key] = value
}
// Get gets the transient storage for `key` at the given `addr`.
func (t transientStorage) Get(addr common.Address, key common.Hash) common.Hash {
val, ok := t[addr]
if !ok {
return common.Hash{}
}
return val[key]
}
// Copy does a deep copy of the transientStorage
func (t transientStorage) Copy() transientStorage {
storage := make(transientStorage)
for key, value := range t {
storage[key] = value.Copy()
}
return storage
}

View File

@ -0,0 +1,354 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package state
import (
"sync"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/metrics"
log "github.com/sirupsen/logrus"
)
var (
// triePrefetchMetricsPrefix is the prefix under which to publish the metrics.
triePrefetchMetricsPrefix = "trie/prefetch/"
)
// triePrefetcher is an active prefetcher, which receives accounts or storage
// items and does trie-loading of them. The goal is to get as much useful content
// into the caches as possible.
//
// Note, the prefetcher's API is not thread safe.
type triePrefetcher struct {
db Database // Database to fetch trie nodes through
root common.Hash // Root hash of the account trie for metrics
fetches map[string]Trie // Partially or fully fetcher tries
fetchers map[string]*subfetcher // Subfetchers for each trie
deliveryMissMeter metrics.Meter
accountLoadMeter metrics.Meter
accountDupMeter metrics.Meter
accountSkipMeter metrics.Meter
accountWasteMeter metrics.Meter
storageLoadMeter metrics.Meter
storageDupMeter metrics.Meter
storageSkipMeter metrics.Meter
storageWasteMeter metrics.Meter
}
func newTriePrefetcher(db Database, root common.Hash, namespace string) *triePrefetcher {
prefix := triePrefetchMetricsPrefix + namespace
p := &triePrefetcher{
db: db,
root: root,
fetchers: make(map[string]*subfetcher), // Active prefetchers use the fetchers map
deliveryMissMeter: metrics.GetOrRegisterMeter(prefix+"/deliverymiss", nil),
accountLoadMeter: metrics.GetOrRegisterMeter(prefix+"/account/load", nil),
accountDupMeter: metrics.GetOrRegisterMeter(prefix+"/account/dup", nil),
accountSkipMeter: metrics.GetOrRegisterMeter(prefix+"/account/skip", nil),
accountWasteMeter: metrics.GetOrRegisterMeter(prefix+"/account/waste", nil),
storageLoadMeter: metrics.GetOrRegisterMeter(prefix+"/storage/load", nil),
storageDupMeter: metrics.GetOrRegisterMeter(prefix+"/storage/dup", nil),
storageSkipMeter: metrics.GetOrRegisterMeter(prefix+"/storage/skip", nil),
storageWasteMeter: metrics.GetOrRegisterMeter(prefix+"/storage/waste", nil),
}
return p
}
// close iterates over all the subfetchers, aborts any that were left spinning
// and reports the stats to the metrics subsystem.
func (p *triePrefetcher) close() {
for _, fetcher := range p.fetchers {
fetcher.abort() // safe to do multiple times
if metrics.Enabled {
if fetcher.root == p.root {
p.accountLoadMeter.Mark(int64(len(fetcher.seen)))
p.accountDupMeter.Mark(int64(fetcher.dups))
p.accountSkipMeter.Mark(int64(len(fetcher.tasks)))
for _, key := range fetcher.used {
delete(fetcher.seen, string(key))
}
p.accountWasteMeter.Mark(int64(len(fetcher.seen)))
} else {
p.storageLoadMeter.Mark(int64(len(fetcher.seen)))
p.storageDupMeter.Mark(int64(fetcher.dups))
p.storageSkipMeter.Mark(int64(len(fetcher.tasks)))
for _, key := range fetcher.used {
delete(fetcher.seen, string(key))
}
p.storageWasteMeter.Mark(int64(len(fetcher.seen)))
}
}
}
// Clear out all fetchers (will crash on a second call, deliberate)
p.fetchers = nil
}
// copy creates a deep-but-inactive copy of the trie prefetcher. Any trie data
// already loaded will be copied over, but no goroutines will be started. This
// is mostly used in the miner which creates a copy of it's actively mutated
// state to be sealed while it may further mutate the state.
func (p *triePrefetcher) copy() *triePrefetcher {
copy := &triePrefetcher{
db: p.db,
root: p.root,
fetches: make(map[string]Trie), // Active prefetchers use the fetches map
deliveryMissMeter: p.deliveryMissMeter,
accountLoadMeter: p.accountLoadMeter,
accountDupMeter: p.accountDupMeter,
accountSkipMeter: p.accountSkipMeter,
accountWasteMeter: p.accountWasteMeter,
storageLoadMeter: p.storageLoadMeter,
storageDupMeter: p.storageDupMeter,
storageSkipMeter: p.storageSkipMeter,
storageWasteMeter: p.storageWasteMeter,
}
// If the prefetcher is already a copy, duplicate the data
if p.fetches != nil {
for root, fetch := range p.fetches {
if fetch == nil {
continue
}
copy.fetches[root] = p.db.CopyTrie(fetch)
}
return copy
}
// Otherwise we're copying an active fetcher, retrieve the current states
for id, fetcher := range p.fetchers {
copy.fetches[id] = fetcher.peek()
}
return copy
}
// prefetch schedules a batch of trie items to prefetch.
func (p *triePrefetcher) prefetch(owner common.Hash, root common.Hash, keys [][]byte) {
// If the prefetcher is an inactive one, bail out
if p.fetches != nil {
return
}
// Active fetcher, schedule the retrievals
id := p.trieID(owner, root)
fetcher := p.fetchers[id]
if fetcher == nil {
fetcher = newSubfetcher(p.db, p.root, owner, root)
p.fetchers[id] = fetcher
}
fetcher.schedule(keys)
}
// trie returns the trie matching the root hash, or nil if the prefetcher doesn't
// have it.
func (p *triePrefetcher) trie(owner common.Hash, root common.Hash) Trie {
// If the prefetcher is inactive, return from existing deep copies
id := p.trieID(owner, root)
if p.fetches != nil {
trie := p.fetches[id]
if trie == nil {
p.deliveryMissMeter.Mark(1)
return nil
}
return p.db.CopyTrie(trie)
}
// Otherwise the prefetcher is active, bail if no trie was prefetched for this root
fetcher := p.fetchers[id]
if fetcher == nil {
p.deliveryMissMeter.Mark(1)
return nil
}
// Interrupt the prefetcher if it's by any chance still running and return
// a copy of any pre-loaded trie.
fetcher.abort() // safe to do multiple times
trie := fetcher.peek()
if trie == nil {
p.deliveryMissMeter.Mark(1)
return nil
}
return trie
}
// used marks a batch of state items used to allow creating statistics as to
// how useful or wasteful the prefetcher is.
func (p *triePrefetcher) used(owner common.Hash, root common.Hash, used [][]byte) {
if fetcher := p.fetchers[p.trieID(owner, root)]; fetcher != nil {
fetcher.used = used
}
}
// trieID returns an unique trie identifier consists the trie owner and root hash.
func (p *triePrefetcher) trieID(owner common.Hash, root common.Hash) string {
return string(append(owner.Bytes(), root.Bytes()...))
}
// subfetcher is a trie fetcher goroutine responsible for pulling entries for a
// single trie. It is spawned when a new root is encountered and lives until the
// main prefetcher is paused and either all requested items are processed or if
// the trie being worked on is retrieved from the prefetcher.
type subfetcher struct {
db Database // Database to load trie nodes through
state common.Hash // Root hash of the state to prefetch
owner common.Hash // Owner of the trie, usually account hash
root common.Hash // Root hash of the trie to prefetch
trie Trie // Trie being populated with nodes
tasks [][]byte // Items queued up for retrieval
lock sync.Mutex // Lock protecting the task queue
wake chan struct{} // Wake channel if a new task is scheduled
stop chan struct{} // Channel to interrupt processing
term chan struct{} // Channel to signal interruption
copy chan chan Trie // Channel to request a copy of the current trie
seen map[string]struct{} // Tracks the entries already loaded
dups int // Number of duplicate preload tasks
used [][]byte // Tracks the entries used in the end
}
// newSubfetcher creates a goroutine to prefetch state items belonging to a
// particular root hash.
func newSubfetcher(db Database, state common.Hash, owner common.Hash, root common.Hash) *subfetcher {
sf := &subfetcher{
db: db,
state: state,
owner: owner,
root: root,
wake: make(chan struct{}, 1),
stop: make(chan struct{}),
term: make(chan struct{}),
copy: make(chan chan Trie),
seen: make(map[string]struct{}),
}
go sf.loop()
return sf
}
// schedule adds a batch of trie keys to the queue to prefetch.
func (sf *subfetcher) schedule(keys [][]byte) {
// Append the tasks to the current queue
sf.lock.Lock()
sf.tasks = append(sf.tasks, keys...)
sf.lock.Unlock()
// Notify the prefetcher, it's fine if it's already terminated
select {
case sf.wake <- struct{}{}:
default:
}
}
// peek tries to retrieve a deep copy of the fetcher's trie in whatever form it
// is currently.
func (sf *subfetcher) peek() Trie {
ch := make(chan Trie)
select {
case sf.copy <- ch:
// Subfetcher still alive, return copy from it
return <-ch
case <-sf.term:
// Subfetcher already terminated, return a copy directly
if sf.trie == nil {
return nil
}
return sf.db.CopyTrie(sf.trie)
}
}
// abort interrupts the subfetcher immediately. It is safe to call abort multiple
// times but it is not thread safe.
func (sf *subfetcher) abort() {
select {
case <-sf.stop:
default:
close(sf.stop)
}
<-sf.term
}
// loop waits for new tasks to be scheduled and keeps loading them until it runs
// out of tasks or its underlying trie is retrieved for committing.
func (sf *subfetcher) loop() {
// No matter how the loop stops, signal anyone waiting that it's terminated
defer close(sf.term)
// Start by opening the trie and stop processing if it fails
if sf.owner == (common.Hash{}) {
trie, err := sf.db.OpenTrie(sf.root)
if err != nil {
log.Warn("Trie prefetcher failed opening trie", "root", sf.root, "err", err)
return
}
sf.trie = trie
} else {
trie, err := sf.db.OpenStorageTrie(sf.state, sf.owner, sf.root)
if err != nil {
log.Warn("Trie prefetcher failed opening trie", "root", sf.root, "err", err)
return
}
sf.trie = trie
}
// Trie opened successfully, keep prefetching items
for {
select {
case <-sf.wake:
// Subfetcher was woken up, retrieve any tasks to avoid spinning the lock
sf.lock.Lock()
tasks := sf.tasks
sf.tasks = nil
sf.lock.Unlock()
// Prefetch any tasks until the loop is interrupted
for i, task := range tasks {
select {
case <-sf.stop:
// If termination is requested, add any leftover back and return
sf.lock.Lock()
sf.tasks = append(sf.tasks, tasks[i:]...)
sf.lock.Unlock()
return
case ch := <-sf.copy:
// Somebody wants a copy of the current trie, grant them
ch <- sf.db.CopyTrie(sf.trie)
default:
// No termination request yet, prefetch the next entry
if _, ok := sf.seen[string(task)]; ok {
sf.dups++
} else {
sf.trie.TryGet(task)
sf.seen[string(task)] = struct{}{}
}
}
}
case ch := <-sf.copy:
// Somebody wants a copy of the current trie, grant them
ch <- sf.db.CopyTrie(sf.trie)
case <-sf.stop:
// Termination is requested, abort and leave remaining tasks
return
}
}
}

View File

@ -0,0 +1,110 @@
// Copyright 2021 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package state
import (
"math/big"
"testing"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
)
func filledStateDB() *StateDB {
state, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil)
// Create an account and check if the retrieved balance is correct
addr := common.HexToAddress("0xaffeaffeaffeaffeaffeaffeaffeaffeaffeaffe")
skey := common.HexToHash("aaa")
sval := common.HexToHash("bbb")
state.SetBalance(addr, big.NewInt(42)) // Change the account trie
state.SetCode(addr, []byte("hello")) // Change an external metadata
state.SetState(addr, skey, sval) // Change the storage trie
for i := 0; i < 100; i++ {
sk := common.BigToHash(big.NewInt(int64(i)))
state.SetState(addr, sk, sk) // Change the storage trie
}
return state
}
func TestCopyAndClose(t *testing.T) {
db := filledStateDB()
prefetcher := newTriePrefetcher(db.db, db.originalRoot, "")
skey := common.HexToHash("aaa")
prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()})
prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()})
time.Sleep(1 * time.Second)
a := prefetcher.trie(common.Hash{}, db.originalRoot)
prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()})
b := prefetcher.trie(common.Hash{}, db.originalRoot)
cpy := prefetcher.copy()
cpy.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()})
cpy.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()})
c := cpy.trie(common.Hash{}, db.originalRoot)
prefetcher.close()
cpy2 := cpy.copy()
cpy2.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()})
d := cpy2.trie(common.Hash{}, db.originalRoot)
cpy.close()
cpy2.close()
if a.Hash() != b.Hash() || a.Hash() != c.Hash() || a.Hash() != d.Hash() {
t.Fatalf("Invalid trie, hashes should be equal: %v %v %v %v", a.Hash(), b.Hash(), c.Hash(), d.Hash())
}
}
func TestUseAfterClose(t *testing.T) {
db := filledStateDB()
prefetcher := newTriePrefetcher(db.db, db.originalRoot, "")
skey := common.HexToHash("aaa")
prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()})
a := prefetcher.trie(common.Hash{}, db.originalRoot)
prefetcher.close()
b := prefetcher.trie(common.Hash{}, db.originalRoot)
if a == nil {
t.Fatal("Prefetching before close should not return nil")
}
if b != nil {
t.Fatal("Trie after close should return nil")
}
}
func TestCopyClose(t *testing.T) {
db := filledStateDB()
prefetcher := newTriePrefetcher(db.db, db.originalRoot, "")
skey := common.HexToHash("aaa")
prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()})
cpy := prefetcher.copy()
a := prefetcher.trie(common.Hash{}, db.originalRoot)
b := cpy.trie(common.Hash{}, db.originalRoot)
prefetcher.close()
c := prefetcher.trie(common.Hash{}, db.originalRoot)
d := cpy.trie(common.Hash{}, db.originalRoot)
if a == nil {
t.Fatal("Prefetching before close should not return nil")
}
if b == nil {
t.Fatal("Copy trie should return nil")
}
if c != nil {
t.Fatal("Trie after close should return nil")
}
if d == nil {
t.Fatal("Copy trie should not return nil")
}
}

View File

@ -0,0 +1,196 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package trie
import (
"fmt"
"github.com/ethereum/go-ethereum/common"
)
// leaf represents a trie leaf node
type leaf struct {
blob []byte // raw blob of leaf
parent common.Hash // the hash of parent node
}
// committer is the tool used for the trie Commit operation. The committer will
// capture all dirty nodes during the commit process and keep them cached in
// insertion order.
type committer struct {
nodes *NodeSet
collectLeaf bool
}
// newCommitter creates a new committer or picks one from the pool.
func newCommitter(nodeset *NodeSet, collectLeaf bool) *committer {
return &committer{
nodes: nodeset,
collectLeaf: collectLeaf,
}
}
// Commit collapses a node down into a hash node.
func (c *committer) Commit(n node) hashNode {
return c.commit(nil, n).(hashNode)
}
// commit collapses a node down into a hash node and returns it.
func (c *committer) commit(path []byte, n node) node {
// if this path is clean, use available cached data
hash, dirty := n.cache()
if hash != nil && !dirty {
return hash
}
// Commit children, then parent, and remove the dirty flag.
switch cn := n.(type) {
case *shortNode:
// Commit child
collapsed := cn.copy()
// If the child is fullNode, recursively commit,
// otherwise it can only be hashNode or valueNode.
if _, ok := cn.Val.(*fullNode); ok {
collapsed.Val = c.commit(append(path, cn.Key...), cn.Val)
}
// The key needs to be copied, since we're adding it to the
// modified nodeset.
collapsed.Key = hexToCompact(cn.Key)
hashedNode := c.store(path, collapsed)
if hn, ok := hashedNode.(hashNode); ok {
return hn
}
return collapsed
case *fullNode:
hashedKids := c.commitChildren(path, cn)
collapsed := cn.copy()
collapsed.Children = hashedKids
hashedNode := c.store(path, collapsed)
if hn, ok := hashedNode.(hashNode); ok {
return hn
}
return collapsed
case hashNode:
return cn
default:
// nil, valuenode shouldn't be committed
panic(fmt.Sprintf("%T: invalid node: %v", n, n))
}
}
// commitChildren commits the children of the given fullnode
func (c *committer) commitChildren(path []byte, n *fullNode) [17]node {
var children [17]node
for i := 0; i < 16; i++ {
child := n.Children[i]
if child == nil {
continue
}
// If it's the hashed child, save the hash value directly.
// Note: it's impossible that the child in range [0, 15]
// is a valueNode.
if hn, ok := child.(hashNode); ok {
children[i] = hn
continue
}
// Commit the child recursively and store the "hashed" value.
// Note the returned node can be some embedded nodes, so it's
// possible the type is not hashNode.
children[i] = c.commit(append(path, byte(i)), child)
}
// For the 17th child, it's possible the type is valuenode.
if n.Children[16] != nil {
children[16] = n.Children[16]
}
return children
}
// store hashes the node n and adds it to the modified nodeset. If leaf collection
// is enabled, leaf nodes will be tracked in the modified nodeset as well.
func (c *committer) store(path []byte, n node) node {
// Larger nodes are replaced by their hash and stored in the database.
var hash, _ = n.cache()
// This was not generated - must be a small node stored in the parent.
// In theory, we should check if the node is leaf here (embedded node
// usually is leaf node). But small value (less than 32bytes) is not
// our target (leaves in account trie only).
if hash == nil {
// The node is embedded in its parent, in other words, this node
// will not be stored in the database independently, mark it as
// deleted only if the node was existent in database before.
if _, ok := c.nodes.accessList[string(path)]; ok {
c.nodes.markDeleted(path)
}
return n
}
// We have the hash already, estimate the RLP encoding-size of the node.
// The size is used for mem tracking, does not need to be exact
var (
size = estimateSize(n)
nhash = common.BytesToHash(hash)
mnode = &memoryNode{
hash: nhash,
node: simplifyNode(n),
size: uint16(size),
}
)
// Collect the dirty node to nodeset for return.
c.nodes.markUpdated(path, mnode)
// Collect the corresponding leaf node if it's required. We don't check
// full node since it's impossible to store value in fullNode. The key
// length of leaves should be exactly same.
if c.collectLeaf {
if sn, ok := n.(*shortNode); ok {
if val, ok := sn.Val.(valueNode); ok {
c.nodes.addLeaf(&leaf{blob: val, parent: nhash})
}
}
}
return hash
}
// estimateSize estimates the size of an rlp-encoded node, without actually
// rlp-encoding it (zero allocs). This method has been experimentally tried, and with a trie
// with 1000 leaves, the only errors above 1% are on small shortnodes, where this
// method overestimates by 2 or 3 bytes (e.g. 37 instead of 35)
func estimateSize(n node) int {
switch n := n.(type) {
case *shortNode:
// A short node contains a compacted key, and a value.
return 3 + len(n.Key) + estimateSize(n.Val)
case *fullNode:
// A full node contains up to 16 hashes (some nils), and a key
s := 3
for i := 0; i < 16; i++ {
if child := n.Children[i]; child != nil {
s += estimateSize(child)
} else {
s++
}
}
return s
case valueNode:
return 1 + len(n)
case hashNode:
return 1 + len(n)
default:
panic(fmt.Sprintf("node type %T", n))
}
}

View File

@ -18,25 +18,50 @@ package trie
import ( import (
"errors" "errors"
"runtime"
"sync"
"time"
"github.com/VictoriaMetrics/fastcache" "github.com/VictoriaMetrics/fastcache"
"github.com/cerc-io/ipld-eth-statedb/internal"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie"
"github.com/ipfs/go-cid" log "github.com/sirupsen/logrus"
) )
type CidKey = cid.Cid // Database is an intermediate write layer between the trie data structures and
// the disk database. The aim is to accumulate trie writes in-memory and only
func isEmpty(key CidKey) bool { // periodically flush a couple tries to disk, garbage collecting the remainder.
return len(key.KeyString()) == 0 //
} // Note, the trie Database is **not** thread safe in its mutations, but it **is**
// thread safe in providing individual, independent node access. The rationale
// Database is an intermediate read-only layer between the trie data structures and // behind this split design is to provide read access to RPC handlers and sync
// the disk database. This trie Database is thread safe in providing individual, // servers even while the trie is executing expensive garbage collection.
// independent node access.
type Database struct { type Database struct {
diskdb ethdb.KeyValueStore // Persistent storage for matured trie nodes diskdb ethdb.Database // Persistent storage for matured trie nodes
cleans *fastcache.Cache // GC friendly memory cache of clean node RLPs cleans *fastcache.Cache // GC friendly memory cache of clean node RLPs
dirties map[common.Hash]*cachedNode // Data and references relationships of dirty trie nodes
oldest common.Hash // Oldest tracked node, flush-list head
newest common.Hash // Newest tracked node, flush-list tail
gctime time.Duration // Time spent on garbage collection since last commit
gcnodes uint64 // Nodes garbage collected since last commit
gcsize common.StorageSize // Data storage garbage collected since last commit
flushtime time.Duration // Time spent on data flushing since last commit
flushnodes uint64 // Nodes flushed since last commit
flushsize common.StorageSize // Data storage flushed since last commit
dirtiesSize common.StorageSize // Storage size of the dirty node cache (exc. metadata)
childrenSize common.StorageSize // Storage size of the external children tracking
preimages *preimageStore // The store for caching preimages
lock sync.RWMutex
} }
// Config defines all necessary options for database. // Config defines all necessary options for database.
@ -46,14 +71,14 @@ type Config = trie.Config
// NewDatabase creates a new trie database to store ephemeral trie content before // NewDatabase creates a new trie database to store ephemeral trie content before
// its written out to disk or garbage collected. No read cache is created, so all // its written out to disk or garbage collected. No read cache is created, so all
// data retrievals will hit the underlying disk database. // data retrievals will hit the underlying disk database.
func NewDatabase(diskdb ethdb.KeyValueStore) *Database { func NewDatabase(diskdb ethdb.Database) *Database {
return NewDatabaseWithConfig(diskdb, nil) return NewDatabaseWithConfig(diskdb, nil)
} }
// NewDatabaseWithConfig creates a new trie database to store ephemeral trie content // NewDatabaseWithConfig creates a new trie database to store ephemeral trie content
// before it's written out to disk or garbage collected. It also acts as a read cache // before its written out to disk or garbage collected. It also acts as a read cache
// for nodes loaded from disk. // for nodes loaded from disk.
func NewDatabaseWithConfig(diskdb ethdb.KeyValueStore, config *Config) *Database { func NewDatabaseWithConfig(diskdb ethdb.Database, config *Config) *Database {
var cleans *fastcache.Cache var cleans *fastcache.Cache
if config != nil && config.Cache > 0 { if config != nil && config.Cache > 0 {
if config.Journal == "" { if config.Journal == "" {
@ -62,44 +87,354 @@ func NewDatabaseWithConfig(diskdb ethdb.KeyValueStore, config *Config) *Database
cleans = fastcache.LoadFromFileOrNew(config.Journal, config.Cache*1024*1024) cleans = fastcache.LoadFromFileOrNew(config.Journal, config.Cache*1024*1024)
} }
} }
var preimage *preimageStore
if config != nil && config.Preimages {
preimage = newPreimageStore(diskdb)
}
db := &Database{ db := &Database{
diskdb: diskdb, diskdb: diskdb,
cleans: cleans, cleans: cleans,
dirties: map[common.Hash]*cachedNode{{}: {
children: make(map[common.Hash]uint16),
}},
preimages: preimage,
} }
return db return db
} }
// DiskDB retrieves the persistent storage backing the trie database. // insert inserts a simplified trie node into the memory database.
func (db *Database) DiskDB() ethdb.KeyValueStore { // All nodes inserted by this function will be reference tracked
return db.diskdb // and in theory should only used for **trie nodes** insertion.
func (db *Database) insert(hash common.Hash, size int, node node) {
// If the node's already cached, skip
if _, ok := db.dirties[hash]; ok {
return
}
memcacheDirtyWriteMeter.Mark(int64(size))
// Create the cached entry for this node
entry := &cachedNode{
node: node,
size: uint16(size),
flushPrev: db.newest,
}
entry.forChilds(func(child common.Hash) {
if c := db.dirties[child]; c != nil {
c.parents++
}
})
db.dirties[hash] = entry
// Update the flush-list endpoints
if db.oldest == (common.Hash{}) {
db.oldest, db.newest = hash, hash
} else {
db.dirties[db.newest].flushNext, db.newest = hash, hash
}
db.dirtiesSize += common.StorageSize(common.HashLength + entry.size)
} }
// Node retrieves an encoded trie node by CID. If it cannot be found // Node retrieves an encoded cached trie node from memory. If it cannot be found
// cached in memory, it queries the persistent database. // cached, the method queries the persistent database for the content.
func (db *Database) Node(key CidKey) ([]byte, error) { func (db *Database) Node(hash common.Hash, codec uint64) ([]byte, error) {
// It doesn't make sense to retrieve the metaroot // It doesn't make sense to retrieve the metaroot
if isEmpty(key) { if hash == (common.Hash{}) {
return nil, errors.New("not found") return nil, errors.New("not found")
} }
cidbytes := key.Bytes()
// Retrieve the node from the clean cache if available // Retrieve the node from the clean cache if available
if db.cleans != nil { if db.cleans != nil {
if enc := db.cleans.Get(nil, cidbytes); enc != nil { if enc := db.cleans.Get(nil, hash[:]); enc != nil {
memcacheCleanHitMeter.Mark(1)
memcacheCleanReadMeter.Mark(int64(len(enc)))
return enc, nil return enc, nil
} }
} }
// Retrieve the node from the dirty cache if available
db.lock.RLock()
dirty := db.dirties[hash]
db.lock.RUnlock()
if dirty != nil {
memcacheDirtyHitMeter.Mark(1)
memcacheDirtyReadMeter.Mark(int64(dirty.size))
return dirty.rlp(), nil
}
memcacheDirtyMissMeter.Mark(1)
// Content unavailable in memory, attempt to retrieve from disk // Content unavailable in memory, attempt to retrieve from disk
enc, err := db.diskdb.Get(cidbytes) cid, err := internal.Keccak256ToCid(codec, hash[:])
if err != nil {
return nil, err
}
enc, err := db.diskdb.Get(cid.Bytes())
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(enc) != 0 { if len(enc) != 0 {
if db.cleans != nil { if db.cleans != nil {
db.cleans.Set(cidbytes, enc) db.cleans.Set(hash[:], enc)
memcacheCleanMissMeter.Mark(1)
memcacheCleanWriteMeter.Mark(int64(len(enc)))
} }
return enc, nil return enc, nil
} }
return nil, errors.New("not found") return nil, errors.New("not found")
} }
// Nodes retrieves the hashes of all the nodes cached within the memory database.
// This method is extremely expensive and should only be used to validate internal
// states in test code.
func (db *Database) Nodes() []common.Hash {
db.lock.RLock()
defer db.lock.RUnlock()
var hashes = make([]common.Hash, 0, len(db.dirties))
for hash := range db.dirties {
if hash != (common.Hash{}) { // Special case for "root" references/nodes
hashes = append(hashes, hash)
}
}
return hashes
}
// Reference adds a new reference from a parent node to a child node.
// This function is used to add reference between internal trie node
// and external node(e.g. storage trie root), all internal trie nodes
// are referenced together by database itself.
func (db *Database) Reference(child common.Hash, parent common.Hash) {
db.lock.Lock()
defer db.lock.Unlock()
db.reference(child, parent)
}
// reference is the private locked version of Reference.
func (db *Database) reference(child common.Hash, parent common.Hash) {
// If the node does not exist, it's a node pulled from disk, skip
node, ok := db.dirties[child]
if !ok {
return
}
// If the reference already exists, only duplicate for roots
if db.dirties[parent].children == nil {
db.dirties[parent].children = make(map[common.Hash]uint16)
db.childrenSize += cachedNodeChildrenSize
} else if _, ok = db.dirties[parent].children[child]; ok && parent != (common.Hash{}) {
return
}
node.parents++
db.dirties[parent].children[child]++
if db.dirties[parent].children[child] == 1 {
db.childrenSize += common.HashLength + 2 // uint16 counter
}
}
// Dereference removes an existing reference from a root node.
func (db *Database) Dereference(root common.Hash) {
// Sanity check to ensure that the meta-root is not removed
if root == (common.Hash{}) {
log.Error("Attempted to dereference the trie cache meta root")
return
}
db.lock.Lock()
defer db.lock.Unlock()
nodes, storage, start := len(db.dirties), db.dirtiesSize, time.Now()
db.dereference(root, common.Hash{})
db.gcnodes += uint64(nodes - len(db.dirties))
db.gcsize += storage - db.dirtiesSize
db.gctime += time.Since(start)
memcacheGCTimeTimer.Update(time.Since(start))
memcacheGCSizeMeter.Mark(int64(storage - db.dirtiesSize))
memcacheGCNodesMeter.Mark(int64(nodes - len(db.dirties)))
log.Debug("Dereferenced trie from memory database", "nodes", nodes-len(db.dirties), "size", storage-db.dirtiesSize, "time", time.Since(start),
"gcnodes", db.gcnodes, "gcsize", db.gcsize, "gctime", db.gctime, "livenodes", len(db.dirties), "livesize", db.dirtiesSize)
}
// dereference is the private locked version of Dereference.
func (db *Database) dereference(child common.Hash, parent common.Hash) {
// Dereference the parent-child
node := db.dirties[parent]
if node.children != nil && node.children[child] > 0 {
node.children[child]--
if node.children[child] == 0 {
delete(node.children, child)
db.childrenSize -= (common.HashLength + 2) // uint16 counter
}
}
// If the child does not exist, it's a previously committed node.
node, ok := db.dirties[child]
if !ok {
return
}
// If there are no more references to the child, delete it and cascade
if node.parents > 0 {
// This is a special cornercase where a node loaded from disk (i.e. not in the
// memcache any more) gets reinjected as a new node (short node split into full,
// then reverted into short), causing a cached node to have no parents. That is
// no problem in itself, but don't make maxint parents out of it.
node.parents--
}
if node.parents == 0 {
// Remove the node from the flush-list
switch child {
case db.oldest:
db.oldest = node.flushNext
db.dirties[node.flushNext].flushPrev = common.Hash{}
case db.newest:
db.newest = node.flushPrev
db.dirties[node.flushPrev].flushNext = common.Hash{}
default:
db.dirties[node.flushPrev].flushNext = node.flushNext
db.dirties[node.flushNext].flushPrev = node.flushPrev
}
// Dereference all children and delete the node
node.forChilds(func(hash common.Hash) {
db.dereference(hash, child)
})
delete(db.dirties, child)
db.dirtiesSize -= common.StorageSize(common.HashLength + int(node.size))
if node.children != nil {
db.childrenSize -= cachedNodeChildrenSize
}
}
}
// Update inserts the dirty nodes in provided nodeset into database and
// link the account trie with multiple storage tries if necessary.
func (db *Database) Update(nodes *MergedNodeSet) error {
db.lock.Lock()
defer db.lock.Unlock()
// Insert dirty nodes into the database. In the same tree, it must be
// ensured that children are inserted first, then parent so that children
// can be linked with their parent correctly.
//
// Note, the storage tries must be flushed before the account trie to
// retain the invariant that children go into the dirty cache first.
var order []common.Hash
for owner := range nodes.sets {
if owner == (common.Hash{}) {
continue
}
order = append(order, owner)
}
if _, ok := nodes.sets[common.Hash{}]; ok {
order = append(order, common.Hash{})
}
for _, owner := range order {
subset := nodes.sets[owner]
subset.forEachWithOrder(func(path string, n *memoryNode) {
if n.isDeleted() {
return // ignore deletion
}
db.insert(n.hash, int(n.size), n.node)
})
}
// Link up the account trie and storage trie if the node points
// to an account trie leaf.
if set, present := nodes.sets[common.Hash{}]; present {
for _, n := range set.leaves {
var account types.StateAccount
if err := rlp.DecodeBytes(n.blob, &account); err != nil {
return err
}
if account.Root != types.EmptyRootHash {
db.reference(account.Root, n.parent)
}
}
}
return nil
}
// Size returns the current storage size of the memory cache in front of the
// persistent database layer.
func (db *Database) Size() (common.StorageSize, common.StorageSize) {
db.lock.RLock()
defer db.lock.RUnlock()
// db.dirtiesSize only contains the useful data in the cache, but when reporting
// the total memory consumption, the maintenance metadata is also needed to be
// counted.
var metadataSize = common.StorageSize((len(db.dirties) - 1) * cachedNodeSize)
var metarootRefs = common.StorageSize(len(db.dirties[common.Hash{}].children) * (common.HashLength + 2))
var preimageSize common.StorageSize
if db.preimages != nil {
preimageSize = db.preimages.size()
}
return db.dirtiesSize + db.childrenSize + metadataSize - metarootRefs, preimageSize
}
// GetReader retrieves a node reader belonging to the given state root.
func (db *Database) GetReader(root common.Hash, codec uint64) Reader {
return &hashReader{db: db, codec: codec}
}
// hashReader is reader of hashDatabase which implements the Reader interface.
type hashReader struct {
db *Database
codec uint64
}
// Node retrieves the trie node with the given node hash.
func (reader *hashReader) Node(owner common.Hash, path []byte, hash common.Hash) (node, error) {
blob, err := reader.NodeBlob(owner, path, hash)
if err != nil {
return nil, err
}
return decodeNodeUnsafe(hash[:], blob)
}
// NodeBlob retrieves the RLP-encoded trie node blob with the given node hash.
func (reader *hashReader) NodeBlob(_ common.Hash, _ []byte, hash common.Hash) ([]byte, error) {
return reader.db.Node(hash, reader.codec)
}
// saveCache saves clean state cache to given directory path
// using specified CPU cores.
func (db *Database) saveCache(dir string, threads int) error {
if db.cleans == nil {
return nil
}
log.Info("Writing clean trie cache to disk", "path", dir, "threads", threads)
start := time.Now()
err := db.cleans.SaveToFileConcurrent(dir, threads)
if err != nil {
log.Error("Failed to persist clean trie cache", "error", err)
return err
}
log.Info("Persisted the clean trie cache", "path", dir, "elapsed", common.PrettyDuration(time.Since(start)))
return nil
}
// SaveCache atomically saves fast cache data to the given dir using all
// available CPU cores.
func (db *Database) SaveCache(dir string) error {
return db.saveCache(dir, runtime.GOMAXPROCS(0))
}
// SaveCachePeriodically atomically saves fast cache data to the given dir with
// the specified interval. All dump operation will only use a single CPU core.
func (db *Database) SaveCachePeriodically(dir string, interval time.Duration, stopCh <-chan struct{}) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
db.saveCache(dir, 1)
case <-stopCh:
return
}
}
}
// Scheme returns the node scheme used in the database.
func (db *Database) Scheme() string {
return rawdb.HashScheme
}

View File

@ -0,0 +1,27 @@
package trie
import "github.com/ethereum/go-ethereum/metrics"
var (
memcacheCleanHitMeter = metrics.NewRegisteredMeter("trie/memcache/clean/hit", nil)
memcacheCleanMissMeter = metrics.NewRegisteredMeter("trie/memcache/clean/miss", nil)
memcacheCleanReadMeter = metrics.NewRegisteredMeter("trie/memcache/clean/read", nil)
memcacheCleanWriteMeter = metrics.NewRegisteredMeter("trie/memcache/clean/write", nil)
memcacheDirtyHitMeter = metrics.NewRegisteredMeter("trie/memcache/dirty/hit", nil)
memcacheDirtyMissMeter = metrics.NewRegisteredMeter("trie/memcache/dirty/miss", nil)
memcacheDirtyReadMeter = metrics.NewRegisteredMeter("trie/memcache/dirty/read", nil)
memcacheDirtyWriteMeter = metrics.NewRegisteredMeter("trie/memcache/dirty/write", nil)
memcacheFlushTimeTimer = metrics.NewRegisteredResettingTimer("trie/memcache/flush/time", nil)
memcacheFlushNodesMeter = metrics.NewRegisteredMeter("trie/memcache/flush/nodes", nil)
memcacheFlushSizeMeter = metrics.NewRegisteredMeter("trie/memcache/flush/size", nil)
memcacheGCTimeTimer = metrics.NewRegisteredResettingTimer("trie/memcache/gc/time", nil)
memcacheGCNodesMeter = metrics.NewRegisteredMeter("trie/memcache/gc/nodes", nil)
memcacheGCSizeMeter = metrics.NewRegisteredMeter("trie/memcache/gc/size", nil)
memcacheCommitTimeTimer = metrics.NewRegisteredResettingTimer("trie/memcache/commit/time", nil)
memcacheCommitNodesMeter = metrics.NewRegisteredMeter("trie/memcache/commit/nodes", nil)
memcacheCommitSizeMeter = metrics.NewRegisteredMeter("trie/memcache/commit/size", nil)
)

View File

@ -0,0 +1,183 @@
package trie
import (
"fmt"
"io"
"reflect"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/rlp"
)
// rawNode is a simple binary blob used to differentiate between collapsed trie
// nodes and already encoded RLP binary blobs (while at the same time store them
// in the same cache fields).
type rawNode []byte
func (n rawNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") }
func (n rawNode) fstring(ind string) string { panic("this should never end up in a live trie") }
func (n rawNode) EncodeRLP(w io.Writer) error {
_, err := w.Write(n)
return err
}
// rawFullNode represents only the useful data content of a full node, with the
// caches and flags stripped out to minimize its data storage. This type honors
// the same RLP encoding as the original parent.
type rawFullNode [17]node
func (n rawFullNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") }
func (n rawFullNode) fstring(ind string) string { panic("this should never end up in a live trie") }
func (n rawFullNode) EncodeRLP(w io.Writer) error {
eb := rlp.NewEncoderBuffer(w)
n.encode(eb)
return eb.Flush()
}
// rawShortNode represents only the useful data content of a short node, with the
// caches and flags stripped out to minimize its data storage. This type honors
// the same RLP encoding as the original parent.
type rawShortNode struct {
Key []byte
Val node
}
func (n rawShortNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") }
func (n rawShortNode) fstring(ind string) string { panic("this should never end up in a live trie") }
// cachedNode is all the information we know about a single cached trie node
// in the memory database write layer.
type cachedNode struct {
node node // Cached collapsed trie node, or raw rlp data
size uint16 // Byte size of the useful cached data
parents uint32 // Number of live nodes referencing this one
children map[common.Hash]uint16 // External children referenced by this node
flushPrev common.Hash // Previous node in the flush-list
flushNext common.Hash // Next node in the flush-list
}
// cachedNodeSize is the raw size of a cachedNode data structure without any
// node data included. It's an approximate size, but should be a lot better
// than not counting them.
var cachedNodeSize = int(reflect.TypeOf(cachedNode{}).Size())
// cachedNodeChildrenSize is the raw size of an initialized but empty external
// reference map.
const cachedNodeChildrenSize = 48
// rlp returns the raw rlp encoded blob of the cached trie node, either directly
// from the cache, or by regenerating it from the collapsed node.
func (n *cachedNode) rlp() []byte {
if node, ok := n.node.(rawNode); ok {
return node
}
return nodeToBytes(n.node)
}
// obj returns the decoded and expanded trie node, either directly from the cache,
// or by regenerating it from the rlp encoded blob.
func (n *cachedNode) obj(hash common.Hash) node {
if node, ok := n.node.(rawNode); ok {
// The raw-blob format nodes are loaded either from the
// clean cache or the database, they are all in their own
// copy and safe to use unsafe decoder.
return mustDecodeNodeUnsafe(hash[:], node)
}
return expandNode(hash[:], n.node)
}
// forChilds invokes the callback for all the tracked children of this node,
// both the implicit ones from inside the node as well as the explicit ones
// from outside the node.
func (n *cachedNode) forChilds(onChild func(hash common.Hash)) {
for child := range n.children {
onChild(child)
}
if _, ok := n.node.(rawNode); !ok {
forGatherChildren(n.node, onChild)
}
}
// forGatherChildren traverses the node hierarchy of a collapsed storage node and
// invokes the callback for all the hashnode children.
func forGatherChildren(n node, onChild func(hash common.Hash)) {
switch n := n.(type) {
case *rawShortNode:
forGatherChildren(n.Val, onChild)
case rawFullNode:
for i := 0; i < 16; i++ {
forGatherChildren(n[i], onChild)
}
case hashNode:
onChild(common.BytesToHash(n))
case valueNode, nil, rawNode:
default:
panic(fmt.Sprintf("unknown node type: %T", n))
}
}
// simplifyNode traverses the hierarchy of an expanded memory node and discards
// all the internal caches, returning a node that only contains the raw data.
func simplifyNode(n node) node {
switch n := n.(type) {
case *shortNode:
// Short nodes discard the flags and cascade
return &rawShortNode{Key: n.Key, Val: simplifyNode(n.Val)}
case *fullNode:
// Full nodes discard the flags and cascade
node := rawFullNode(n.Children)
for i := 0; i < len(node); i++ {
if node[i] != nil {
node[i] = simplifyNode(node[i])
}
}
return node
case valueNode, hashNode, rawNode:
return n
default:
panic(fmt.Sprintf("unknown node type: %T", n))
}
}
// expandNode traverses the node hierarchy of a collapsed storage node and converts
// all fields and keys into expanded memory form.
func expandNode(hash hashNode, n node) node {
switch n := n.(type) {
case *rawShortNode:
// Short nodes need key and child expansion
return &shortNode{
Key: compactToHex(n.Key),
Val: expandNode(nil, n.Val),
flags: nodeFlag{
hash: hash,
},
}
case rawFullNode:
// Full nodes need child expansion
node := &fullNode{
flags: nodeFlag{
hash: hash,
},
}
for i := 0; i < len(node.Children); i++ {
if n[i] != nil {
node.Children[i] = expandNode(nil, n[i])
}
}
return node
case valueNode, hashNode:
return n
default:
panic(fmt.Sprintf("unknown node type: %T", n))
}
}

View File

@ -14,21 +14,20 @@
// You should have received a copy of the GNU Lesser General Public License // You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package trie_test package trie
import ( import (
"testing" "testing"
"github.com/ethereum/go-ethereum/ethdb/memorydb" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie"
) )
// Tests that the trie database returns a missing trie node error if attempting // Tests that the trie database returns a missing trie node error if attempting
// to retrieve the meta root. // to retrieve the meta root.
func TestDatabaseMetarootFetch(t *testing.T) { func TestDatabaseMetarootFetch(t *testing.T) {
db := trie.NewDatabase(memorydb.New()) db := NewDatabase(rawdb.NewMemoryDatabase())
if _, err := db.Node(trie.CidKey{}); err == nil { if _, err := db.Node(common.Hash{}, StateTrieCodec); err == nil {
t.Fatalf("metaroot retrieval succeeded") t.Fatalf("metaroot retrieval succeeded")
} }
} }

View File

@ -16,8 +16,81 @@
package trie package trie
// Trie keys are dealt with in three distinct encodings:
//
// KEYBYTES encoding contains the actual key and nothing else. This encoding is the
// input to most API functions.
//
// HEX encoding contains one byte for each nibble of the key and an optional trailing
// 'terminator' byte of value 0x10 which indicates whether or not the node at the key
// contains a value. Hex key encoding is used for nodes loaded in memory because it's
// convenient to access.
//
// COMPACT encoding is defined by the Ethereum Yellow Paper (it's called "hex prefix
// encoding" there) and contains the bytes of the key and a flag. The high nibble of the
// first byte contains the flag; the lowest bit encoding the oddness of the length and
// the second-lowest encoding whether the node at the key is a value node. The low nibble
// of the first byte is zero in the case of an even number of nibbles and the first nibble
// in the case of an odd number. All remaining nibbles (now an even number) fit properly
// into the remaining bytes. Compact encoding is used for nodes stored on disk.
// HexToCompact converts a hex path to the compact encoded format
func HexToCompact(hex []byte) []byte {
return hexToCompact(hex)
}
func hexToCompact(hex []byte) []byte {
terminator := byte(0)
if hasTerm(hex) {
terminator = 1
hex = hex[:len(hex)-1]
}
buf := make([]byte, len(hex)/2+1)
buf[0] = terminator << 5 // the flag byte
if len(hex)&1 == 1 {
buf[0] |= 1 << 4 // odd flag
buf[0] |= hex[0] // first nibble is contained in the first byte
hex = hex[1:]
}
decodeNibbles(hex, buf[1:])
return buf
}
// hexToCompactInPlace places the compact key in input buffer, returning the length
// needed for the representation
func hexToCompactInPlace(hex []byte) int {
var (
hexLen = len(hex) // length of the hex input
firstByte = byte(0)
)
// Check if we have a terminator there
if hexLen > 0 && hex[hexLen-1] == 16 {
firstByte = 1 << 5
hexLen-- // last part was the terminator, ignore that
}
var (
binLen = hexLen/2 + 1
ni = 0 // index in hex
bi = 1 // index in bin (compact)
)
if hexLen&1 == 1 {
firstByte |= 1 << 4 // odd flag
firstByte |= hex[0] // first nibble is contained in the first byte
ni++
}
for ; ni < hexLen; bi, ni = bi+1, ni+2 {
hex[bi] = hex[ni]<<4 | hex[ni+1]
}
hex[0] = firstByte
return binLen
}
// CompactToHex converts a compact encoded path to hex format // CompactToHex converts a compact encoded path to hex format
func CompactToHex(compact []byte) []byte { func CompactToHex(compact []byte) []byte {
return compactToHex(compact)
}
func compactToHex(compact []byte) []byte {
if len(compact) == 0 { if len(compact) == 0 {
return compact return compact
} }
@ -62,6 +135,20 @@ func decodeNibbles(nibbles []byte, bytes []byte) {
} }
} }
// prefixLen returns the length of the common prefix of a and b.
func prefixLen(a, b []byte) int {
var i, length = 0, len(a)
if len(b) < length {
length = len(b)
}
for ; i < length; i++ {
if a[i] != b[i] {
break
}
}
return i
}
// hasTerm returns whether a hex key has the terminator flag. // hasTerm returns whether a hex key has the terminator flag.
func hasTerm(s []byte) bool { func hasTerm(s []byte) bool {
return len(s) > 0 && s[len(s)-1] == 16 return len(s) > 0 && s[len(s)-1] == 16

View File

@ -0,0 +1,141 @@
// Copyright 2014 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package trie
import (
"bytes"
crand "crypto/rand"
"encoding/hex"
"math/rand"
"testing"
)
func TestHexCompact(t *testing.T) {
tests := []struct{ hex, compact []byte }{
// empty keys, with and without terminator.
{hex: []byte{}, compact: []byte{0x00}},
{hex: []byte{16}, compact: []byte{0x20}},
// odd length, no terminator
{hex: []byte{1, 2, 3, 4, 5}, compact: []byte{0x11, 0x23, 0x45}},
// even length, no terminator
{hex: []byte{0, 1, 2, 3, 4, 5}, compact: []byte{0x00, 0x01, 0x23, 0x45}},
// odd length, terminator
{hex: []byte{15, 1, 12, 11, 8, 16 /*term*/}, compact: []byte{0x3f, 0x1c, 0xb8}},
// even length, terminator
{hex: []byte{0, 15, 1, 12, 11, 8, 16 /*term*/}, compact: []byte{0x20, 0x0f, 0x1c, 0xb8}},
}
for _, test := range tests {
if c := hexToCompact(test.hex); !bytes.Equal(c, test.compact) {
t.Errorf("hexToCompact(%x) -> %x, want %x", test.hex, c, test.compact)
}
if h := compactToHex(test.compact); !bytes.Equal(h, test.hex) {
t.Errorf("compactToHex(%x) -> %x, want %x", test.compact, h, test.hex)
}
}
}
func TestHexKeybytes(t *testing.T) {
tests := []struct{ key, hexIn, hexOut []byte }{
{key: []byte{}, hexIn: []byte{16}, hexOut: []byte{16}},
{key: []byte{}, hexIn: []byte{}, hexOut: []byte{16}},
{
key: []byte{0x12, 0x34, 0x56},
hexIn: []byte{1, 2, 3, 4, 5, 6, 16},
hexOut: []byte{1, 2, 3, 4, 5, 6, 16},
},
{
key: []byte{0x12, 0x34, 0x5},
hexIn: []byte{1, 2, 3, 4, 0, 5, 16},
hexOut: []byte{1, 2, 3, 4, 0, 5, 16},
},
{
key: []byte{0x12, 0x34, 0x56},
hexIn: []byte{1, 2, 3, 4, 5, 6},
hexOut: []byte{1, 2, 3, 4, 5, 6, 16},
},
}
for _, test := range tests {
if h := keybytesToHex(test.key); !bytes.Equal(h, test.hexOut) {
t.Errorf("keybytesToHex(%x) -> %x, want %x", test.key, h, test.hexOut)
}
if k := hexToKeyBytes(test.hexIn); !bytes.Equal(k, test.key) {
t.Errorf("hexToKeyBytes(%x) -> %x, want %x", test.hexIn, k, test.key)
}
}
}
func TestHexToCompactInPlace(t *testing.T) {
for i, key := range []string{
"00",
"060a040c0f000a090b040803010801010900080d090a0a0d0903000b10",
"10",
} {
hexBytes, _ := hex.DecodeString(key)
exp := hexToCompact(hexBytes)
sz := hexToCompactInPlace(hexBytes)
got := hexBytes[:sz]
if !bytes.Equal(exp, got) {
t.Fatalf("test %d: encoding err\ninp %v\ngot %x\nexp %x\n", i, key, got, exp)
}
}
}
func TestHexToCompactInPlaceRandom(t *testing.T) {
for i := 0; i < 10000; i++ {
l := rand.Intn(128)
key := make([]byte, l)
crand.Read(key)
hexBytes := keybytesToHex(key)
hexOrig := []byte(string(hexBytes))
exp := hexToCompact(hexBytes)
sz := hexToCompactInPlace(hexBytes)
got := hexBytes[:sz]
if !bytes.Equal(exp, got) {
t.Fatalf("encoding err \ncpt %x\nhex %x\ngot %x\nexp %x\n",
key, hexOrig, got, exp)
}
}
}
func BenchmarkHexToCompact(b *testing.B) {
testBytes := []byte{0, 15, 1, 12, 11, 8, 16 /*term*/}
for i := 0; i < b.N; i++ {
hexToCompact(testBytes)
}
}
func BenchmarkCompactToHex(b *testing.B) {
testBytes := []byte{0, 15, 1, 12, 11, 8, 16 /*term*/}
for i := 0; i < b.N; i++ {
compactToHex(testBytes)
}
}
func BenchmarkKeybytesToHex(b *testing.B) {
testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16}
for i := 0; i < b.N; i++ {
keybytesToHex(testBytes)
}
}
func BenchmarkHexToKeybytes(b *testing.B) {
testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16}
for i := 0; i < b.N; i++ {
hexToKeyBytes(testBytes)
}
}

View File

@ -27,7 +27,7 @@ import (
// information necessary for retrieving the missing node. // information necessary for retrieving the missing node.
type MissingNodeError struct { type MissingNodeError struct {
Owner common.Hash // owner of the trie if it's 2-layered trie Owner common.Hash // owner of the trie if it's 2-layered trie
NodeHash []byte // hash of the missing node NodeHash common.Hash // hash of the missing node
Path []byte // hex-encoded path to the missing node Path []byte // hex-encoded path to the missing node
err error // concrete error for missing trie node err error // concrete error for missing trie node
} }

View File

@ -21,7 +21,6 @@ import (
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie"
"golang.org/x/crypto/sha3" "golang.org/x/crypto/sha3"
) )
@ -99,7 +98,7 @@ func (h *hasher) hashShortNodeChildren(n *shortNode) (collapsed, cached *shortNo
// Previously, we did copy this one. We don't seem to need to actually // Previously, we did copy this one. We don't seem to need to actually
// do that, since we don't overwrite/reuse keys // do that, since we don't overwrite/reuse keys
//cached.Key = common.CopyBytes(n.Key) //cached.Key = common.CopyBytes(n.Key)
collapsed.Key = trie.HexToCompact(n.Key) collapsed.Key = hexToCompact(n.Key)
// Unless the child is a valuenode or hashnode, hash it // Unless the child is a valuenode or hashnode, hash it
switch n.Val.(type) { switch n.Val.(type) {
case *fullNode, *shortNode: case *fullNode, *shortNode:

View File

@ -18,14 +18,26 @@ package trie
import ( import (
"bytes" "bytes"
"container/heap"
"errors" "errors"
"time"
"github.com/ethereum/go-ethereum/statediff/indexer/database/metrics"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/core/types"
gethtrie "github.com/ethereum/go-ethereum/trie"
) )
// NodeIterator is a re-export of the go-ethereum interface // NodeIterator is an iterator to traverse the trie pre-order.
type NodeIterator = trie.NodeIterator type NodeIterator = gethtrie.NodeIterator
// NodeResolver is used for looking up trie nodes before reaching into the real
// persistent layer. This is not mandatory, rather is an optimization for cases
// where trie nodes can be recovered from some external mechanism without reading
// from disk. In those cases, this resolver allows short circuiting accesses and
// returning them from memory.
type NodeResolver = gethtrie.NodeResolver
// Iterator is a key-value trie iterator that traverses a Trie. // Iterator is a key-value trie iterator that traverses a Trie.
type Iterator struct { type Iterator struct {
@ -82,7 +94,7 @@ type nodeIterator struct {
path []byte // Path to the current node path []byte // Path to the current node
err error // Failure set in case of an internal error in the iterator err error // Failure set in case of an internal error in the iterator
resolver trie.NodeResolver // Optional intermediate resolver above the disk layer resolver NodeResolver // optional node resolver for avoiding disk hits
} }
// errIteratorEnd is stored in nodeIterator.err when iteration is done. // errIteratorEnd is stored in nodeIterator.err when iteration is done.
@ -99,7 +111,7 @@ func (e seekError) Error() string {
} }
func newNodeIterator(trie *Trie, start []byte) NodeIterator { func newNodeIterator(trie *Trie, start []byte) NodeIterator {
if trie.Hash() == emptyRoot { if trie.Hash() == types.EmptyRootHash {
return &nodeIterator{ return &nodeIterator{
trie: trie, trie: trie,
err: errIteratorEnd, err: errIteratorEnd,
@ -110,7 +122,7 @@ func newNodeIterator(trie *Trie, start []byte) NodeIterator {
return it return it
} }
func (it *nodeIterator) AddResolver(resolver trie.NodeResolver) { func (it *nodeIterator) AddResolver(resolver NodeResolver) {
it.resolver = resolver it.resolver = resolver
} }
@ -128,6 +140,14 @@ func (it *nodeIterator) Parent() common.Hash {
return it.stack[len(it.stack)-1].parent return it.stack[len(it.stack)-1].parent
} }
func (it *nodeIterator) ParentPath() []byte {
if len(it.stack) == 0 {
return []byte{}
}
pathlen := it.stack[len(it.stack)-1].pathlen
return it.path[:pathlen]
}
func (it *nodeIterator) Leaf() bool { func (it *nodeIterator) Leaf() bool {
return hasTerm(it.path) return hasTerm(it.path)
} }
@ -241,7 +261,7 @@ func (it *nodeIterator) seek(prefix []byte) error {
func (it *nodeIterator) init() (*nodeIteratorState, error) { func (it *nodeIterator) init() (*nodeIteratorState, error) {
root := it.trie.Hash() root := it.trie.Hash()
state := &nodeIteratorState{node: it.trie.root, index: -1} state := &nodeIteratorState{node: it.trie.root, index: -1}
if root != emptyRoot { if root != types.EmptyRootHash {
state.hash = root state.hash = root
} }
return state, state.resolve(it, nil) return state, state.resolve(it, nil)
@ -320,7 +340,12 @@ func (it *nodeIterator) resolveHash(hash hashNode, path []byte) (node, error) {
} }
} }
} }
return it.trie.resolveHash(hash, path) // Retrieve the specified node from the underlying node reader.
// it.trie.resolveAndTrack is not used since in that function the
// loaded blob will be tracked, while it's not required here since
// all loaded nodes won't be linked to trie at all and track nodes
// may lead to out-of-memory issue.
return it.trie.reader.node(path, common.BytesToHash(hash))
} }
func (it *nodeIterator) resolveBlob(hash hashNode, path []byte) ([]byte, error) { func (it *nodeIterator) resolveBlob(hash hashNode, path []byte) ([]byte, error) {
@ -329,7 +354,12 @@ func (it *nodeIterator) resolveBlob(hash hashNode, path []byte) ([]byte, error)
return blob, nil return blob, nil
} }
} }
return it.trie.resolveBlob(hash, path) // Retrieve the specified node from the underlying node reader.
// it.trie.resolveAndTrack is not used since in that function the
// loaded blob will be tracked, while it's not required here since
// all loaded nodes won't be linked to trie at all and track nodes
// may lead to out-of-memory issue.
return it.trie.reader.nodeBlob(path, common.BytesToHash(hash))
} }
func (st *nodeIteratorState) resolve(it *nodeIterator, path []byte) error { func (st *nodeIteratorState) resolve(it *nodeIterator, path []byte) error {
@ -455,3 +485,248 @@ func (it *nodeIterator) pop() {
it.stack[len(it.stack)-1] = nil it.stack[len(it.stack)-1] = nil
it.stack = it.stack[:len(it.stack)-1] it.stack = it.stack[:len(it.stack)-1]
} }
func compareNodes(a, b NodeIterator) int {
if cmp := bytes.Compare(a.Path(), b.Path()); cmp != 0 {
return cmp
}
if a.Leaf() && !b.Leaf() {
return -1
} else if b.Leaf() && !a.Leaf() {
return 1
}
if cmp := bytes.Compare(a.Hash().Bytes(), b.Hash().Bytes()); cmp != 0 {
return cmp
}
if a.Leaf() && b.Leaf() {
return bytes.Compare(a.LeafBlob(), b.LeafBlob())
}
return 0
}
type differenceIterator struct {
a, b NodeIterator // Nodes returned are those in b - a.
eof bool // Indicates a has run out of elements
count int // Number of nodes scanned on either trie
}
// NewDifferenceIterator constructs a NodeIterator that iterates over elements in b that
// are not in a. Returns the iterator, and a pointer to an integer recording the number
// of nodes seen.
func NewDifferenceIterator(a, b NodeIterator) (NodeIterator, *int) {
a.Next(true)
it := &differenceIterator{
a: a,
b: b,
}
return it, &it.count
}
func (it *differenceIterator) Hash() common.Hash {
return it.b.Hash()
}
func (it *differenceIterator) Parent() common.Hash {
return it.b.Parent()
}
func (it *differenceIterator) ParentPath() []byte {
return it.b.ParentPath()
}
func (it *differenceIterator) Leaf() bool {
return it.b.Leaf()
}
func (it *differenceIterator) LeafKey() []byte {
return it.b.LeafKey()
}
func (it *differenceIterator) LeafBlob() []byte {
return it.b.LeafBlob()
}
func (it *differenceIterator) LeafProof() [][]byte {
return it.b.LeafProof()
}
func (it *differenceIterator) Path() []byte {
return it.b.Path()
}
func (it *differenceIterator) NodeBlob() []byte {
return it.b.NodeBlob()
}
func (it *differenceIterator) AddResolver(resolver NodeResolver) {
panic("not implemented")
}
func (it *differenceIterator) Next(bool) bool {
defer metrics.UpdateDuration(time.Now(), metrics.IndexerMetrics.DifferenceIteratorNextTimer)
// Invariants:
// - We always advance at least one element in b.
// - At the start of this function, a's path is lexically greater than b's.
if !it.b.Next(true) {
return false
}
it.count++
if it.eof {
// a has reached eof, so we just return all elements from b
return true
}
for {
switch compareNodes(it.a, it.b) {
case -1:
// b jumped past a; advance a
if !it.a.Next(true) {
it.eof = true
return true
}
it.count++
case 1:
// b is before a
return true
case 0:
// a and b are identical; skip this whole subtree if the nodes have hashes
hasHash := it.a.Hash() == common.Hash{}
if !it.b.Next(hasHash) {
return false
}
it.count++
if !it.a.Next(hasHash) {
it.eof = true
return true
}
it.count++
}
}
}
func (it *differenceIterator) Error() error {
if err := it.a.Error(); err != nil {
return err
}
return it.b.Error()
}
type nodeIteratorHeap []NodeIterator
func (h nodeIteratorHeap) Len() int { return len(h) }
func (h nodeIteratorHeap) Less(i, j int) bool { return compareNodes(h[i], h[j]) < 0 }
func (h nodeIteratorHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *nodeIteratorHeap) Push(x interface{}) { *h = append(*h, x.(NodeIterator)) }
func (h *nodeIteratorHeap) Pop() interface{} {
n := len(*h)
x := (*h)[n-1]
*h = (*h)[0 : n-1]
return x
}
type unionIterator struct {
items *nodeIteratorHeap // Nodes returned are the union of the ones in these iterators
count int // Number of nodes scanned across all tries
}
// NewUnionIterator constructs a NodeIterator that iterates over elements in the union
// of the provided NodeIterators. Returns the iterator, and a pointer to an integer
// recording the number of nodes visited.
func NewUnionIterator(iters []NodeIterator) (NodeIterator, *int) {
h := make(nodeIteratorHeap, len(iters))
copy(h, iters)
heap.Init(&h)
ui := &unionIterator{items: &h}
return ui, &ui.count
}
func (it *unionIterator) Hash() common.Hash {
return (*it.items)[0].Hash()
}
func (it *unionIterator) Parent() common.Hash {
return (*it.items)[0].Parent()
}
func (it *unionIterator) ParentPath() []byte {
return (*it.items)[0].ParentPath()
}
func (it *unionIterator) Leaf() bool {
return (*it.items)[0].Leaf()
}
func (it *unionIterator) LeafKey() []byte {
return (*it.items)[0].LeafKey()
}
func (it *unionIterator) LeafBlob() []byte {
return (*it.items)[0].LeafBlob()
}
func (it *unionIterator) LeafProof() [][]byte {
return (*it.items)[0].LeafProof()
}
func (it *unionIterator) Path() []byte {
return (*it.items)[0].Path()
}
func (it *unionIterator) NodeBlob() []byte {
return (*it.items)[0].NodeBlob()
}
func (it *unionIterator) AddResolver(resolver NodeResolver) {
panic("not implemented")
}
// Next returns the next node in the union of tries being iterated over.
//
// It does this by maintaining a heap of iterators, sorted by the iteration
// order of their next elements, with one entry for each source trie. Each
// time Next() is called, it takes the least element from the heap to return,
// advancing any other iterators that also point to that same element. These
// iterators are called with descend=false, since we know that any nodes under
// these nodes will also be duplicates, found in the currently selected iterator.
// Whenever an iterator is advanced, it is pushed back into the heap if it still
// has elements remaining.
//
// In the case that descend=false - eg, we're asked to ignore all subnodes of the
// current node - we also advance any iterators in the heap that have the current
// path as a prefix.
func (it *unionIterator) Next(descend bool) bool {
if len(*it.items) == 0 {
return false
}
// Get the next key from the union
least := heap.Pop(it.items).(NodeIterator)
// Skip over other nodes as long as they're identical, or, if we're not descending, as
// long as they have the same prefix as the current node.
for len(*it.items) > 0 && ((!descend && bytes.HasPrefix((*it.items)[0].Path(), least.Path())) || compareNodes(least, (*it.items)[0]) == 0) {
skipped := heap.Pop(it.items).(NodeIterator)
// Skip the whole subtree if the nodes have hashes; otherwise just skip this node
if skipped.Next(skipped.Hash() == common.Hash{}) {
it.count++
// If there are more elements, push the iterator back on the heap
heap.Push(it.items, skipped)
}
}
if least.Next(descend) {
it.count++
heap.Push(it.items, least)
}
return len(*it.items) > 0
}
func (it *unionIterator) Error() error {
for i := 0; i < len(*it.items); i++ {
if err := (*it.items)[i].Error(); err != nil {
return err
}
}
return nil
}

View File

@ -14,28 +14,24 @@
// You should have received a copy of the GNU Lesser General Public License // You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package trie_test package trie
import ( import (
"bytes" "bytes"
"context" "encoding/binary"
"fmt" "fmt"
"math/rand"
"testing" "testing"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/statediff/indexer/database/sql/postgres" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb/memorydb"
geth_trie "github.com/ethereum/go-ethereum/trie" geth_trie "github.com/ethereum/go-ethereum/trie"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie"
) )
var ( var (
dbConfig, _ = postgres.DefaultConfig.WithEnv() packableTestData = []kvsi{
trieConfig = trie.Config{Cache: 256}
ctx = context.Background()
testdata0 = []kvs{
{"one", 1}, {"one", 1},
{"two", 2}, {"two", 2},
{"three", 3}, {"three", 3},
@ -43,20 +39,10 @@ var (
{"five", 5}, {"five", 5},
{"ten", 10}, {"ten", 10},
} }
testdata1 = []kvs{
{"barb", 0},
{"bard", 1},
{"bars", 2},
{"bar", 3},
{"fab", 4},
{"food", 5},
{"foos", 6},
{"foo", 7},
}
) )
func TestEmptyIterator(t *testing.T) { func TestEmptyIterator(t *testing.T) {
trie := trie.NewEmpty(trie.NewDatabase(rawdb.NewMemoryDatabase())) trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
iter := trie.NodeIterator(nil) iter := trie.NodeIterator(nil)
seen := make(map[string]struct{}) seen := make(map[string]struct{})
@ -69,63 +55,163 @@ func TestEmptyIterator(t *testing.T) {
} }
func TestIterator(t *testing.T) { func TestIterator(t *testing.T) {
edb := rawdb.NewMemoryDatabase() db := NewDatabase(rawdb.NewMemoryDatabase())
db := geth_trie.NewDatabase(edb) trie := NewEmpty(db)
origtrie := geth_trie.NewEmpty(db) vals := []struct{ k, v string }{
all, err := updateTrie(origtrie, testdata0) {"do", "verb"},
if err != nil { {"ether", "wookiedoo"},
t.Fatal(err) {"horse", "stallion"},
{"shaman", "horse"},
{"doge", "coin"},
{"dog", "puppy"},
{"somethingveryoddindeedthis is", "myothernodedata"},
} }
// commit and index data all := make(map[string]string)
root := commitTrie(t, db, origtrie) for _, val := range vals {
tr := indexTrie(t, edb, root) all[val.k] = val.v
trie.Update([]byte(val.k), []byte(val.v))
}
root, nodes := trie.Commit(false)
db.Update(NewWithNodeSet(nodes))
found := make(map[string]int64) trie, _ = New(TrieID(root), db, StateTrieCodec)
it := trie.NewIterator(tr.NodeIterator(nil)) found := make(map[string]string)
it := NewIterator(trie.NodeIterator(nil))
for it.Next() { for it.Next() {
found[string(it.Key)] = unpackValue(it.Value) found[string(it.Key)] = string(it.Value)
} }
if len(found) != len(all) { for k, v := range all {
t.Errorf("number of iterated values do not match: want %d, found %d", len(all), len(found)) if found[k] != v {
} t.Errorf("iterator value mismatch for %s: got %q want %q", k, found[k], v)
for k, kv := range all {
if found[k] != kv.v {
t.Errorf("iterator value mismatch for %s: got %q want %q", k, found[k], kv.v)
} }
} }
} }
type kv struct {
k, v []byte
t bool
}
func TestIteratorLargeData(t *testing.T) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
vals := make(map[string]*kv)
for i := byte(0); i < 255; i++ {
value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
value2 := &kv{common.LeftPadBytes([]byte{10, i}, 32), []byte{i}, false}
trie.Update(value.k, value.v)
trie.Update(value2.k, value2.v)
vals[string(value.k)] = value
vals[string(value2.k)] = value2
}
it := NewIterator(trie.NodeIterator(nil))
for it.Next() {
vals[string(it.Key)].t = true
}
var untouched []*kv
for _, value := range vals {
if !value.t {
untouched = append(untouched, value)
}
}
if len(untouched) > 0 {
t.Errorf("Missed %d nodes", len(untouched))
for _, value := range untouched {
t.Error(value)
}
}
}
// Tests that the node iterator indeed walks over the entire database contents.
func TestNodeIteratorCoverage(t *testing.T) {
db, trie, _ := makeTestTrie(t)
// Create some arbitrary test trie to iterate
// Gather all the node hashes found by the iterator
hashes := make(map[common.Hash]struct{})
for it := trie.NodeIterator(nil); it.Next(true); {
if it.Hash() != (common.Hash{}) {
hashes[it.Hash()] = struct{}{}
}
}
// Cross check the hashes and the database itself
for hash := range hashes {
if _, err := db.Node(hash, StateTrieCodec); err != nil {
t.Errorf("failed to retrieve reported node %x: %v", hash, err)
}
}
for hash, obj := range db.dirties {
if obj != nil && hash != (common.Hash{}) {
if _, ok := hashes[hash]; !ok {
t.Errorf("state entry not reported %x", hash)
}
}
}
it := db.diskdb.NewIterator(nil, nil)
for it.Next() {
key := it.Key()
if _, ok := hashes[common.BytesToHash(key)]; !ok {
t.Errorf("state entry not reported %x", key)
}
}
it.Release()
}
type kvs struct{ k, v string }
var testdata1 = []kvs{
{"barb", "ba"},
{"bard", "bc"},
{"bars", "bb"},
{"bar", "b"},
{"fab", "z"},
{"food", "ab"},
{"foos", "aa"},
{"foo", "a"},
}
var testdata2 = []kvs{
{"aardvark", "c"},
{"bar", "b"},
{"barb", "bd"},
{"bars", "be"},
{"fab", "z"},
{"foo", "a"},
{"foos", "aa"},
{"food", "ab"},
{"jars", "d"},
}
func TestIteratorSeek(t *testing.T) { func TestIteratorSeek(t *testing.T) {
edb := rawdb.NewMemoryDatabase() trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
db := geth_trie.NewDatabase(edb) for _, val := range testdata1 {
orig := geth_trie.NewEmpty(geth_trie.NewDatabase(rawdb.NewMemoryDatabase())) trie.Update([]byte(val.k), []byte(val.v))
if _, err := updateTrie(orig, testdata1); err != nil {
t.Fatal(err)
} }
root := commitTrie(t, db, orig)
tr := indexTrie(t, edb, root)
// Seek to the middle. // Seek to the middle.
it := trie.NewIterator(tr.NodeIterator([]byte("fab"))) it := NewIterator(trie.NodeIterator([]byte("fab")))
if err := checkIteratorOrder(testdata1[4:], it); err != nil { if err := checkIteratorOrder(testdata1[4:], it); err != nil {
t.Fatal(err) t.Fatal(err)
} }
// Seek to a non-existent key. // Seek to a non-existent key.
it = trie.NewIterator(tr.NodeIterator([]byte("barc"))) it = NewIterator(trie.NodeIterator([]byte("barc")))
if err := checkIteratorOrder(testdata1[1:], it); err != nil { if err := checkIteratorOrder(testdata1[1:], it); err != nil {
t.Fatal(err) t.Fatal(err)
} }
// Seek beyond the end. // Seek beyond the end.
it = trie.NewIterator(tr.NodeIterator([]byte("z"))) it = NewIterator(trie.NodeIterator([]byte("z")))
if err := checkIteratorOrder(nil, it); err != nil { if err := checkIteratorOrder(nil, it); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
func checkIteratorOrder(want []kvs, it *trie.Iterator) error { func checkIteratorOrder(want []kvs, it *Iterator) error {
for it.Next() { for it.Next() {
if len(want) == 0 { if len(want) == 0 {
return fmt.Errorf("didn't expect any more values, got key %q", it.Key) return fmt.Errorf("didn't expect any more values, got key %q", it.Key)
@ -141,11 +227,366 @@ func checkIteratorOrder(want []kvs, it *trie.Iterator) error {
return nil return nil
} }
func TestDifferenceIterator(t *testing.T) {
dba := NewDatabase(rawdb.NewMemoryDatabase())
triea := NewEmpty(dba)
for _, val := range testdata1 {
triea.Update([]byte(val.k), []byte(val.v))
}
rootA, nodesA := triea.Commit(false)
dba.Update(NewWithNodeSet(nodesA))
triea, _ = New(TrieID(rootA), dba, StateTrieCodec)
dbb := NewDatabase(rawdb.NewMemoryDatabase())
trieb := NewEmpty(dbb)
for _, val := range testdata2 {
trieb.Update([]byte(val.k), []byte(val.v))
}
rootB, nodesB := trieb.Commit(false)
dbb.Update(NewWithNodeSet(nodesB))
trieb, _ = New(TrieID(rootB), dbb, StateTrieCodec)
found := make(map[string]string)
di, _ := NewDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil))
it := NewIterator(di)
for it.Next() {
found[string(it.Key)] = string(it.Value)
}
all := []struct{ k, v string }{
{"aardvark", "c"},
{"barb", "bd"},
{"bars", "be"},
{"jars", "d"},
}
for _, item := range all {
if found[item.k] != item.v {
t.Errorf("iterator value mismatch for %s: got %v want %v", item.k, found[item.k], item.v)
}
}
if len(found) != len(all) {
t.Errorf("iterator count mismatch: got %d values, want %d", len(found), len(all))
}
}
func TestUnionIterator(t *testing.T) {
dba := NewDatabase(rawdb.NewMemoryDatabase())
triea := NewEmpty(dba)
for _, val := range testdata1 {
triea.Update([]byte(val.k), []byte(val.v))
}
rootA, nodesA := triea.Commit(false)
dba.Update(NewWithNodeSet(nodesA))
triea, _ = New(TrieID(rootA), dba, StateTrieCodec)
dbb := NewDatabase(rawdb.NewMemoryDatabase())
trieb := NewEmpty(dbb)
for _, val := range testdata2 {
trieb.Update([]byte(val.k), []byte(val.v))
}
rootB, nodesB := trieb.Commit(false)
dbb.Update(NewWithNodeSet(nodesB))
trieb, _ = New(TrieID(rootB), dbb, StateTrieCodec)
di, _ := NewUnionIterator([]NodeIterator{triea.NodeIterator(nil), trieb.NodeIterator(nil)})
it := NewIterator(di)
all := []struct{ k, v string }{
{"aardvark", "c"},
{"barb", "ba"},
{"barb", "bd"},
{"bard", "bc"},
{"bars", "bb"},
{"bars", "be"},
{"bar", "b"},
{"fab", "z"},
{"food", "ab"},
{"foos", "aa"},
{"foo", "a"},
{"jars", "d"},
}
for i, kv := range all {
if !it.Next() {
t.Errorf("Iterator ends prematurely at element %d", i)
}
if kv.k != string(it.Key) {
t.Errorf("iterator value mismatch for element %d: got key %s want %s", i, it.Key, kv.k)
}
if kv.v != string(it.Value) {
t.Errorf("iterator value mismatch for element %d: got value %s want %s", i, it.Value, kv.v)
}
}
if it.Next() {
t.Errorf("Iterator returned extra values.")
}
}
func TestIteratorNoDups(t *testing.T) {
tr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
for _, val := range testdata1 {
tr.Update([]byte(val.k), []byte(val.v))
}
checkIteratorNoDups(t, tr.NodeIterator(nil), nil)
}
// This test checks that nodeIterator.Next can be retried after inserting missing trie nodes.
func TestIteratorContinueAfterErrorMemonly(t *testing.T) { testIteratorContinueAfterError(t, true) }
func testIteratorContinueAfterError(t *testing.T, memonly bool) {
diskdb := rawdb.NewMemoryDatabase()
triedb := NewDatabase(diskdb)
tr := NewEmpty(triedb)
for _, val := range testdata1 {
tr.Update([]byte(val.k), []byte(val.v))
}
_, nodes := tr.Commit(false)
triedb.Update(NewWithNodeSet(nodes))
// if !memonly {
// triedb.Commit(tr.Hash(), false)
// }
wantNodeCount := checkIteratorNoDups(t, tr.NodeIterator(nil), nil)
var (
diskKeys [][]byte
memKeys []common.Hash
)
if memonly {
memKeys = triedb.Nodes()
} else {
it := diskdb.NewIterator(nil, nil)
for it.Next() {
diskKeys = append(diskKeys, it.Key())
}
it.Release()
}
for i := 0; i < 20; i++ {
// Create trie that will load all nodes from DB.
tr, _ := New(TrieID(tr.Hash()), triedb, StateTrieCodec)
// Remove a random node from the database. It can't be the root node
// because that one is already loaded.
var (
rkey common.Hash
rval []byte
robj *cachedNode
)
for {
if memonly {
rkey = memKeys[rand.Intn(len(memKeys))]
} else {
copy(rkey[:], diskKeys[rand.Intn(len(diskKeys))])
}
if rkey != tr.Hash() {
break
}
}
if memonly {
robj = triedb.dirties[rkey]
delete(triedb.dirties, rkey)
} else {
rval, _ = diskdb.Get(rkey[:])
diskdb.Delete(rkey[:])
}
// Iterate until the error is hit.
seen := make(map[string]bool)
it := tr.NodeIterator(nil)
checkIteratorNoDups(t, it, seen)
missing, ok := it.Error().(*MissingNodeError)
if !ok || missing.NodeHash != rkey {
t.Fatal("didn't hit missing node, got", it.Error())
}
// Add the node back and continue iteration.
if memonly {
triedb.dirties[rkey] = robj
} else {
diskdb.Put(rkey[:], rval)
}
checkIteratorNoDups(t, it, seen)
if it.Error() != nil {
t.Fatal("unexpected error", it.Error())
}
if len(seen) != wantNodeCount {
t.Fatal("wrong node iteration count, got", len(seen), "want", wantNodeCount)
}
}
}
// Similar to the test above, this one checks that failure to create nodeIterator at a
// certain key prefix behaves correctly when Next is called. The expectation is that Next
// should retry seeking before returning true for the first time.
func TestIteratorContinueAfterSeekErrorMemonly(t *testing.T) {
testIteratorContinueAfterSeekError(t, true)
}
func testIteratorContinueAfterSeekError(t *testing.T, memonly bool) {
// Commit test trie to db, then remove the node containing "bars".
diskdb := rawdb.NewMemoryDatabase()
triedb := NewDatabase(diskdb)
ctr := NewEmpty(triedb)
for _, val := range testdata1 {
ctr.Update([]byte(val.k), []byte(val.v))
}
root, nodes := ctr.Commit(false)
triedb.Update(NewWithNodeSet(nodes))
// if !memonly {
// triedb.Commit(root, false)
// }
barNodeHash := common.HexToHash("05041990364eb72fcb1127652ce40d8bab765f2bfe53225b1170d276cc101c2e")
var (
barNodeBlob []byte
barNodeObj *cachedNode
)
if memonly {
barNodeObj = triedb.dirties[barNodeHash]
delete(triedb.dirties, barNodeHash)
} else {
barNodeBlob, _ = diskdb.Get(barNodeHash[:])
diskdb.Delete(barNodeHash[:])
}
// Create a new iterator that seeks to "bars". Seeking can't proceed because
// the node is missing.
tr, _ := New(TrieID(root), triedb, StateTrieCodec)
it := tr.NodeIterator([]byte("bars"))
missing, ok := it.Error().(*MissingNodeError)
if !ok {
t.Fatal("want MissingNodeError, got", it.Error())
} else if missing.NodeHash != barNodeHash {
t.Fatal("wrong node missing")
}
// Reinsert the missing node.
if memonly {
triedb.dirties[barNodeHash] = barNodeObj
} else {
diskdb.Put(barNodeHash[:], barNodeBlob)
}
// Check that iteration produces the right set of values.
if err := checkIteratorOrder(testdata1[2:], NewIterator(it)); err != nil {
t.Fatal(err)
}
}
func checkIteratorNoDups(t *testing.T, it NodeIterator, seen map[string]bool) int {
if seen == nil {
seen = make(map[string]bool)
}
for it.Next(true) {
if seen[string(it.Path())] {
t.Fatalf("iterator visited node path %x twice", it.Path())
}
seen[string(it.Path())] = true
}
return len(seen)
}
type loggingTrieDb struct {
*Database
getCount uint64
}
// GetReader retrieves a node reader belonging to the given state root.
func (db *loggingTrieDb) GetReader(root common.Hash, codec uint64) Reader {
return &loggingNodeReader{db, codec}
}
// hashReader is reader of hashDatabase which implements the Reader interface.
type loggingNodeReader struct {
db *loggingTrieDb
codec uint64
}
// Node retrieves the trie node with the given node hash.
func (reader *loggingNodeReader) Node(owner common.Hash, path []byte, hash common.Hash) (node, error) {
blob, err := reader.NodeBlob(owner, path, hash)
if err != nil {
return nil, err
}
return decodeNodeUnsafe(hash[:], blob)
}
// NodeBlob retrieves the RLP-encoded trie node blob with the given node hash.
func (reader *loggingNodeReader) NodeBlob(_ common.Hash, _ []byte, hash common.Hash) ([]byte, error) {
reader.db.getCount++
return reader.db.Node(hash, reader.codec)
}
func newLoggingStateTrie(id *ID, db *Database, codec uint64) (*StateTrie, *loggingTrieDb, error) {
logdb := &loggingTrieDb{Database: db}
trie, err := New(id, logdb, codec)
if err != nil {
return nil, nil, err
}
return &StateTrie{trie: *trie, preimages: db.preimages}, logdb, nil
}
// makeLargeTestTrie create a sample test trie
func makeLargeTestTrie(t testing.TB) (*Database, *StateTrie, *loggingTrieDb) {
// Create an empty trie
triedb := NewDatabase(rawdb.NewDatabase(memorydb.New()))
trie, logDb, err := newLoggingStateTrie(TrieID(common.Hash{}), triedb, StateTrieCodec)
if err != nil {
t.Fatal(err)
}
// Fill it with some arbitrary data
for i := 0; i < 10000; i++ {
key := make([]byte, 32)
val := make([]byte, 32)
binary.BigEndian.PutUint64(key, uint64(i))
binary.BigEndian.PutUint64(val, uint64(i))
key = crypto.Keccak256(key)
val = crypto.Keccak256(val)
trie.Update(key, val)
}
_, nodes := trie.Commit(false)
triedb.Update(NewWithNodeSet(nodes))
// Return the generated trie
return triedb, trie, logDb
}
// Tests that the node iterator indeed walks over the entire database contents.
func TestNodeIteratorLargeTrie(t *testing.T) {
// Create some arbitrary test trie to iterate
_, trie, logDb := makeLargeTestTrie(t)
// Do a seek operation
trie.NodeIterator(common.FromHex("0x77667766776677766778855885885885"))
// master: 24 get operations
// this pr: 5 get operations
if have, want := logDb.getCount, uint64(5); have != want {
t.Fatalf("Wrong number of lookups during seek, have %d want %d", have, want)
}
}
func TestIteratorNodeBlob(t *testing.T) { func TestIteratorNodeBlob(t *testing.T) {
// var (
// db = rawdb.NewMemoryDatabase()
// triedb = NewDatabase(db)
// trie = NewEmpty(triedb)
// )
// vals := []struct{ k, v string }{
// {"do", "verb"},
// {"ether", "wookiedoo"},
// {"horse", "stallion"},
// {"shaman", "horse"},
// {"doge", "coin"},
// {"dog", "puppy"},
// {"somethingveryoddindeedthis is", "myothernodedata"},
// }
// all := make(map[string]string)
// for _, val := range vals {
// all[val.k] = val.v
// trie.Update([]byte(val.k), []byte(val.v))
// }
// _, nodes := trie.Commit(false)
// triedb.Update(NewWithNodeSet(nodes))
edb := rawdb.NewMemoryDatabase() edb := rawdb.NewMemoryDatabase()
db := geth_trie.NewDatabase(edb) db := geth_trie.NewDatabase(edb)
orig := geth_trie.NewEmpty(geth_trie.NewDatabase(rawdb.NewMemoryDatabase())) orig := geth_trie.NewEmpty(geth_trie.NewDatabase(rawdb.NewMemoryDatabase()))
if _, err := updateTrie(orig, testdata1); err != nil { if _, err := updateTrie(orig, packableTestData); err != nil {
t.Fatal(err) t.Fatal(err)
} }
root := commitTrie(t, db, orig) root := commitTrie(t, db, orig)
@ -167,7 +608,7 @@ func TestIteratorNodeBlob(t *testing.T) {
for dbIter.Next() { for dbIter.Next() {
got, present := found[common.BytesToHash(dbIter.Key())] got, present := found[common.BytesToHash(dbIter.Key())]
if !present { if !present {
t.Fatalf("Missing trie node %v", dbIter.Key()) t.Fatalf("Miss trie node %v", dbIter.Key())
} }
if !bytes.Equal(got, dbIter.Value()) { if !bytes.Equal(got, dbIter.Value()) {
t.Fatalf("Unexpected trie node want %v got %v", dbIter.Value(), got) t.Fatalf("Unexpected trie node want %v got %v", dbIter.Value(), got)
@ -175,6 +616,6 @@ func TestIteratorNodeBlob(t *testing.T) {
count += 1 count += 1
} }
if count != len(found) { if count != len(found) {
t.Fatal("Find extra trie node via iterator") t.Fatalf("Wrong number of trie nodes found, want %d, got %d", len(found), count)
} }
} }

View File

@ -23,7 +23,6 @@ import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie"
) )
var indices = []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b", "c", "d", "e", "f", "[17]"} var indices = []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b", "c", "d", "e", "f", "[17]"}
@ -157,7 +156,7 @@ func decodeShort(hash, elems []byte) (node, error) {
return nil, err return nil, err
} }
flag := nodeFlag{hash: hash} flag := nodeFlag{hash: hash}
key := trie.CompactToHex(kbuf) key := compactToHex(kbuf)
if hasTerm(key) { if hasTerm(key) {
// value node // value node
val, _, err := rlp.SplitString(rest) val, _, err := rlp.SplitString(rest)

View File

@ -58,3 +58,30 @@ func (n hashNode) encode(w rlp.EncoderBuffer) {
func (n valueNode) encode(w rlp.EncoderBuffer) { func (n valueNode) encode(w rlp.EncoderBuffer) {
w.WriteBytes(n) w.WriteBytes(n)
} }
func (n rawFullNode) encode(w rlp.EncoderBuffer) {
offset := w.List()
for _, c := range n {
if c != nil {
c.encode(w)
} else {
w.Write(rlp.EmptyString)
}
}
w.ListEnd(offset)
}
func (n *rawShortNode) encode(w rlp.EncoderBuffer) {
offset := w.List()
w.WriteBytes(n.Key)
if n.Val != nil {
n.Val.encode(w)
} else {
w.Write(rlp.EmptyString)
}
w.ListEnd(offset)
}
func (n rawNode) encode(w rlp.EncoderBuffer) {
w.Write(n)
}

View File

@ -20,6 +20,7 @@ import (
"bytes" "bytes"
"testing" "testing"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
) )
@ -92,3 +93,123 @@ func TestDecodeFullNode(t *testing.T) {
t.Fatalf("decode full node err: %v", err) t.Fatalf("decode full node err: %v", err)
} }
} }
// goos: darwin
// goarch: arm64
// pkg: github.com/ethereum/go-ethereum/trie
// BenchmarkEncodeShortNode
// BenchmarkEncodeShortNode-8 16878850 70.81 ns/op 48 B/op 1 allocs/op
func BenchmarkEncodeShortNode(b *testing.B) {
node := &shortNode{
Key: []byte{0x1, 0x2},
Val: hashNode(randBytes(32)),
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
nodeToBytes(node)
}
}
// goos: darwin
// goarch: arm64
// pkg: github.com/ethereum/go-ethereum/trie
// BenchmarkEncodeFullNode
// BenchmarkEncodeFullNode-8 4323273 284.4 ns/op 576 B/op 1 allocs/op
func BenchmarkEncodeFullNode(b *testing.B) {
node := &fullNode{}
for i := 0; i < 16; i++ {
node.Children[i] = hashNode(randBytes(32))
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
nodeToBytes(node)
}
}
// goos: darwin
// goarch: arm64
// pkg: github.com/ethereum/go-ethereum/trie
// BenchmarkDecodeShortNode
// BenchmarkDecodeShortNode-8 7925638 151.0 ns/op 157 B/op 4 allocs/op
func BenchmarkDecodeShortNode(b *testing.B) {
node := &shortNode{
Key: []byte{0x1, 0x2},
Val: hashNode(randBytes(32)),
}
blob := nodeToBytes(node)
hash := crypto.Keccak256(blob)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
mustDecodeNode(hash, blob)
}
}
// goos: darwin
// goarch: arm64
// pkg: github.com/ethereum/go-ethereum/trie
// BenchmarkDecodeShortNodeUnsafe
// BenchmarkDecodeShortNodeUnsafe-8 9027476 128.6 ns/op 109 B/op 3 allocs/op
func BenchmarkDecodeShortNodeUnsafe(b *testing.B) {
node := &shortNode{
Key: []byte{0x1, 0x2},
Val: hashNode(randBytes(32)),
}
blob := nodeToBytes(node)
hash := crypto.Keccak256(blob)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
mustDecodeNodeUnsafe(hash, blob)
}
}
// goos: darwin
// goarch: arm64
// pkg: github.com/ethereum/go-ethereum/trie
// BenchmarkDecodeFullNode
// BenchmarkDecodeFullNode-8 1597462 761.9 ns/op 1280 B/op 18 allocs/op
func BenchmarkDecodeFullNode(b *testing.B) {
node := &fullNode{}
for i := 0; i < 16; i++ {
node.Children[i] = hashNode(randBytes(32))
}
blob := nodeToBytes(node)
hash := crypto.Keccak256(blob)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
mustDecodeNode(hash, blob)
}
}
// goos: darwin
// goarch: arm64
// pkg: github.com/ethereum/go-ethereum/trie
// BenchmarkDecodeFullNodeUnsafe
// BenchmarkDecodeFullNodeUnsafe-8 1789070 687.1 ns/op 704 B/op 17 allocs/op
func BenchmarkDecodeFullNodeUnsafe(b *testing.B) {
node := &fullNode{}
for i := 0; i < 16; i++ {
node.Children[i] = hashNode(randBytes(32))
}
blob := nodeToBytes(node)
hash := crypto.Keccak256(blob)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
mustDecodeNodeUnsafe(hash, blob)
}
}

218
trie_by_cid/trie/nodeset.go Normal file
View File

@ -0,0 +1,218 @@
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package trie
import (
"fmt"
"reflect"
"sort"
"strings"
"github.com/ethereum/go-ethereum/common"
)
// memoryNode is all the information we know about a single cached trie node
// in the memory.
type memoryNode struct {
hash common.Hash // Node hash, computed by hashing rlp value, empty for deleted nodes
size uint16 // Byte size of the useful cached data, 0 for deleted nodes
node node // Cached collapsed trie node, or raw rlp data, nil for deleted nodes
}
// memoryNodeSize is the raw size of a memoryNode data structure without any
// node data included. It's an approximate size, but should be a lot better
// than not counting them.
// nolint:unused
var memoryNodeSize = int(reflect.TypeOf(memoryNode{}).Size())
// memorySize returns the total memory size used by this node.
// nolint:unused
func (n *memoryNode) memorySize(pathlen int) int {
return int(n.size) + memoryNodeSize + pathlen
}
// rlp returns the raw rlp encoded blob of the cached trie node, either directly
// from the cache, or by regenerating it from the collapsed node.
// nolint:unused
func (n *memoryNode) rlp() []byte {
if node, ok := n.node.(rawNode); ok {
return node
}
return nodeToBytes(n.node)
}
// obj returns the decoded and expanded trie node, either directly from the cache,
// or by regenerating it from the rlp encoded blob.
// nolint:unused
func (n *memoryNode) obj() node {
if node, ok := n.node.(rawNode); ok {
return mustDecodeNode(n.hash[:], node)
}
return expandNode(n.hash[:], n.node)
}
// isDeleted returns the indicator if the node is marked as deleted.
func (n *memoryNode) isDeleted() bool {
return n.hash == (common.Hash{})
}
// nodeWithPrev wraps the memoryNode with the previous node value.
// nolint: unused
type nodeWithPrev struct {
*memoryNode
prev []byte // RLP-encoded previous value, nil means it's non-existent
}
// unwrap returns the internal memoryNode object.
// nolint:unused
func (n *nodeWithPrev) unwrap() *memoryNode {
return n.memoryNode
}
// memorySize returns the total memory size used by this node. It overloads
// the function in memoryNode by counting the size of previous value as well.
// nolint: unused
func (n *nodeWithPrev) memorySize(pathlen int) int {
return n.memoryNode.memorySize(pathlen) + len(n.prev)
}
// NodeSet contains all dirty nodes collected during the commit operation.
// Each node is keyed by path. It's not thread-safe to use.
type NodeSet struct {
owner common.Hash // the identifier of the trie
nodes map[string]*memoryNode // the set of dirty nodes(inserted, updated, deleted)
leaves []*leaf // the list of dirty leaves
updates int // the count of updated and inserted nodes
deletes int // the count of deleted nodes
// The list of accessed nodes, which records the original node value.
// The origin value is expected to be nil for newly inserted node
// and is expected to be non-nil for other types(updated, deleted).
accessList map[string][]byte
}
// NewNodeSet initializes an empty node set to be used for tracking dirty nodes
// from a specific account or storage trie. The owner is zero for the account
// trie and the owning account address hash for storage tries.
func NewNodeSet(owner common.Hash, accessList map[string][]byte) *NodeSet {
return &NodeSet{
owner: owner,
nodes: make(map[string]*memoryNode),
accessList: accessList,
}
}
// forEachWithOrder iterates the dirty nodes with the order from bottom to top,
// right to left, nodes with the longest path will be iterated first.
func (set *NodeSet) forEachWithOrder(callback func(path string, n *memoryNode)) {
var paths sort.StringSlice
for path := range set.nodes {
paths = append(paths, path)
}
// Bottom-up, longest path first
sort.Sort(sort.Reverse(paths))
for _, path := range paths {
callback(path, set.nodes[path])
}
}
// markUpdated marks the node as dirty(newly-inserted or updated).
func (set *NodeSet) markUpdated(path []byte, node *memoryNode) {
set.nodes[string(path)] = node
set.updates += 1
}
// markDeleted marks the node as deleted.
func (set *NodeSet) markDeleted(path []byte) {
set.nodes[string(path)] = &memoryNode{}
set.deletes += 1
}
// addLeaf collects the provided leaf node into set.
func (set *NodeSet) addLeaf(node *leaf) {
set.leaves = append(set.leaves, node)
}
// Size returns the number of dirty nodes in set.
func (set *NodeSet) Size() (int, int) {
return set.updates, set.deletes
}
// Hashes returns the hashes of all updated nodes. TODO(rjl493456442) how can
// we get rid of it?
func (set *NodeSet) Hashes() []common.Hash {
var ret []common.Hash
for _, node := range set.nodes {
ret = append(ret, node.hash)
}
return ret
}
// Summary returns a string-representation of the NodeSet.
func (set *NodeSet) Summary() string {
var out = new(strings.Builder)
fmt.Fprintf(out, "nodeset owner: %v\n", set.owner)
if set.nodes != nil {
for path, n := range set.nodes {
// Deletion
if n.isDeleted() {
fmt.Fprintf(out, " [-]: %x prev: %x\n", path, set.accessList[path])
continue
}
// Insertion
origin, ok := set.accessList[path]
if !ok {
fmt.Fprintf(out, " [+]: %x -> %v\n", path, n.hash)
continue
}
// Update
fmt.Fprintf(out, " [*]: %x -> %v prev: %x\n", path, n.hash, origin)
}
}
for _, n := range set.leaves {
fmt.Fprintf(out, "[leaf]: %v\n", n)
}
return out.String()
}
// MergedNodeSet represents a merged dirty node set for a group of tries.
type MergedNodeSet struct {
sets map[common.Hash]*NodeSet
}
// NewMergedNodeSet initializes an empty merged set.
func NewMergedNodeSet() *MergedNodeSet {
return &MergedNodeSet{sets: make(map[common.Hash]*NodeSet)}
}
// NewWithNodeSet constructs a merged nodeset with the provided single set.
func NewWithNodeSet(set *NodeSet) *MergedNodeSet {
merged := NewMergedNodeSet()
merged.Merge(set)
return merged
}
// Merge merges the provided dirty nodes of a trie into the set. The assumption
// is held that no duplicated set belonging to the same trie will be merged twice.
func (set *MergedNodeSet) Merge(other *NodeSet) error {
_, present := set.sets[other.owner]
if present {
return fmt.Errorf("duplicate trie for owner %#x", other.owner)
}
set.sets[other.owner] = other
return nil
}

View File

@ -0,0 +1,78 @@
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package trie
import (
"sync"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/ethdb"
)
// preimageStore is the store for caching preimages of node key.
type preimageStore struct {
lock sync.RWMutex
disk ethdb.KeyValueStore
preimages map[common.Hash][]byte // Preimages of nodes from the secure trie
preimagesSize common.StorageSize // Storage size of the preimages cache
}
// newPreimageStore initializes the store for caching preimages.
func newPreimageStore(disk ethdb.KeyValueStore) *preimageStore {
return &preimageStore{
disk: disk,
preimages: make(map[common.Hash][]byte),
}
}
// insertPreimage writes a new trie node pre-image to the memory database if it's
// yet unknown. The method will NOT make a copy of the slice, only use if the
// preimage will NOT be changed later on.
func (store *preimageStore) insertPreimage(preimages map[common.Hash][]byte) {
store.lock.Lock()
defer store.lock.Unlock()
for hash, preimage := range preimages {
if _, ok := store.preimages[hash]; ok {
continue
}
store.preimages[hash] = preimage
store.preimagesSize += common.StorageSize(common.HashLength + len(preimage))
}
}
// preimage retrieves a cached trie node pre-image from memory. If it cannot be
// found cached, the method queries the persistent database for the content.
func (store *preimageStore) preimage(hash common.Hash) []byte {
store.lock.RLock()
preimage := store.preimages[hash]
store.lock.RUnlock()
if preimage != nil {
return preimage
}
return rawdb.ReadPreimage(store.disk, hash)
}
// size returns the current storage size of accumulated preimages.
func (store *preimageStore) size() common.StorageSize {
store.lock.RLock()
defer store.lock.RUnlock()
return store.preimagesSize
}

View File

@ -20,9 +20,10 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie"
log "github.com/sirupsen/logrus"
) )
var VerifyProof = trie.VerifyProof var VerifyProof = trie.VerifyProof
@ -61,10 +62,15 @@ func (t *Trie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) e
key = key[1:] key = key[1:]
nodes = append(nodes, n) nodes = append(nodes, n)
case hashNode: case hashNode:
// Retrieve the specified node from the underlying node reader.
// trie.resolveAndTrack is not used since in that function the
// loaded blob will be tracked, while it's not required here since
// all loaded nodes won't be linked to trie at all and track nodes
// may lead to out-of-memory issue.
var err error var err error
tn, err = t.resolveHash(n, prefix) tn, err = t.reader.node(prefix, common.BytesToHash(n))
if err != nil { if err != nil {
log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) log.Error("Unhandled trie error in Trie.Prove", "err", err)
return err return err
} }
default: default:

View File

@ -14,29 +14,39 @@
// You should have received a copy of the GNU Lesser General Public License // You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package trie_test package trie
import ( import (
"bytes" "bytes"
crand "crypto/rand"
"encoding/binary"
"fmt"
mrand "math/rand" mrand "math/rand"
"sort" "sort"
"testing" "testing"
"time"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb/memorydb" "github.com/ethereum/go-ethereum/ethdb/memorydb"
geth_trie "github.com/ethereum/go-ethereum/trie"
. "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie"
) )
// tweakable trie size // Prng is a pseudo random number generator seeded by strong randomness.
var scaleFactor = 512 // The randomness is printed on startup in order to make failures reproducible.
var prng = initRnd()
func init() { func initRnd() *mrand.Rand {
mrand.Seed(time.Now().UnixNano()) var seed [8]byte
crand.Read(seed[:])
rnd := mrand.New(mrand.NewSource(int64(binary.LittleEndian.Uint64(seed[:]))))
fmt.Printf("Seed: %x\n", seed)
return rnd
}
func randBytes(n int) []byte {
r := make([]byte, n)
prng.Read(r)
return r
} }
// makeProvers creates Merkle trie provers based on different implementations to // makeProvers creates Merkle trie provers based on different implementations to
@ -64,7 +74,7 @@ func makeProvers(trie *Trie) []func(key []byte) *memorydb.Database {
} }
func TestProof(t *testing.T) { func TestProof(t *testing.T) {
trie, vals := randomTrie(t, scaleFactor) trie, vals := randomTrie(500)
root := trie.Hash() root := trie.Hash()
for i, prover := range makeProvers(trie) { for i, prover := range makeProvers(trie) {
for _, kv := range vals { for _, kv := range vals {
@ -72,11 +82,11 @@ func TestProof(t *testing.T) {
if proof == nil { if proof == nil {
t.Fatalf("prover %d: missing key %x while constructing proof", i, kv.k) t.Fatalf("prover %d: missing key %x while constructing proof", i, kv.k)
} }
val, err := geth_trie.VerifyProof(root, kv.k, proof) val, err := VerifyProof(root, kv.k, proof)
if err != nil { if err != nil {
t.Fatalf("prover %d: failed to verify proof for key %x: %v\nraw proof: %x", i, kv.k, err, proof) t.Fatalf("prover %d: failed to verify proof for key %x: %v\nraw proof: %x", i, kv.k, err, proof)
} }
if kv.v != unpackValue(val) { if !bytes.Equal(val, kv.v) {
t.Fatalf("prover %d: verified value mismatch for key %x: have %x, want %x", i, kv.k, val, kv.v) t.Fatalf("prover %d: verified value mismatch for key %x: have %x, want %x", i, kv.k, val, kv.v)
} }
} }
@ -84,13 +94,8 @@ func TestProof(t *testing.T) {
} }
func TestOneElementProof(t *testing.T) { func TestOneElementProof(t *testing.T) {
edb := rawdb.NewMemoryDatabase() trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
db := geth_trie.NewDatabase(edb) updateString(trie, "k", "v")
orig := geth_trie.NewEmpty(db)
orig.Update([]byte("k"), packValue(42))
root := commitTrie(t, db, orig)
trie := indexTrie(t, edb, root)
for i, prover := range makeProvers(trie) { for i, prover := range makeProvers(trie) {
proof := prover([]byte("k")) proof := prover([]byte("k"))
if proof == nil { if proof == nil {
@ -99,18 +104,18 @@ func TestOneElementProof(t *testing.T) {
if proof.Len() != 1 { if proof.Len() != 1 {
t.Errorf("prover %d: proof should have one element", i) t.Errorf("prover %d: proof should have one element", i)
} }
val, err := geth_trie.VerifyProof(trie.Hash(), []byte("k"), proof) val, err := VerifyProof(trie.Hash(), []byte("k"), proof)
if err != nil { if err != nil {
t.Fatalf("prover %d: failed to verify proof: %v\nraw proof: %x", i, err, proof) t.Fatalf("prover %d: failed to verify proof: %v\nraw proof: %x", i, err, proof)
} }
if 42 != unpackValue(val) { if !bytes.Equal(val, []byte("v")) {
t.Fatalf("prover %d: verified value mismatch: have %x, want 'k'", i, val) t.Fatalf("prover %d: verified value mismatch: have %x, want 'k'", i, val)
} }
} }
} }
func TestBadProof(t *testing.T) { func TestBadProof(t *testing.T) {
trie, vals := randomTrie(t, 2*scaleFactor) trie, vals := randomTrie(800)
root := trie.Hash() root := trie.Hash()
for i, prover := range makeProvers(trie) { for i, prover := range makeProvers(trie) {
for _, kv := range vals { for _, kv := range vals {
@ -140,12 +145,8 @@ func TestBadProof(t *testing.T) {
// Tests that missing keys can also be proven. The test explicitly uses a single // Tests that missing keys can also be proven. The test explicitly uses a single
// entry trie and checks for missing keys both before and after the single entry. // entry trie and checks for missing keys both before and after the single entry.
func TestMissingKeyProof(t *testing.T) { func TestMissingKeyProof(t *testing.T) {
edb := rawdb.NewMemoryDatabase() trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
db := geth_trie.NewDatabase(edb) updateString(trie, "k", "v")
orig := geth_trie.NewEmpty(db)
orig.Update([]byte("k"), packValue(42))
root := commitTrie(t, db, orig)
trie := indexTrie(t, edb, root)
for i, key := range []string{"a", "j", "l", "z"} { for i, key := range []string{"a", "j", "l", "z"} {
proof := memorydb.New() proof := memorydb.New()
@ -164,15 +165,7 @@ func TestMissingKeyProof(t *testing.T) {
} }
} }
type entry struct { type entrySlice []*kv
k, v []byte
}
func packEntry(kv *kv) *entry {
return &entry{kv.k, packValue(kv.v)}
}
type entrySlice []*entry
func (p entrySlice) Len() int { return len(p) } func (p entrySlice) Len() int { return len(p) }
func (p entrySlice) Less(i, j int) bool { return bytes.Compare(p[i].k, p[j].k) < 0 } func (p entrySlice) Less(i, j int) bool { return bytes.Compare(p[i].k, p[j].k) < 0 }
@ -181,10 +174,10 @@ func (p entrySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
// TestRangeProof tests normal range proof with both edge proofs // TestRangeProof tests normal range proof with both edge proofs
// as the existent proof. The test cases are generated randomly. // as the existent proof. The test cases are generated randomly.
func TestRangeProof(t *testing.T) { func TestRangeProof(t *testing.T) {
trie, vals := randomTrie(t, 8*scaleFactor) trie, vals := randomTrie(4096)
var entries entrySlice var entries entrySlice
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, packEntry(kv)) entries = append(entries, kv)
} }
sort.Sort(entries) sort.Sort(entries)
for i := 0; i < 500; i++ { for i := 0; i < 500; i++ {
@ -214,10 +207,10 @@ func TestRangeProof(t *testing.T) {
// TestRangeProof tests normal range proof with two non-existent proofs. // TestRangeProof tests normal range proof with two non-existent proofs.
// The test cases are generated randomly. // The test cases are generated randomly.
func TestRangeProofWithNonExistentProof(t *testing.T) { func TestRangeProofWithNonExistentProof(t *testing.T) {
trie, vals := randomTrie(t, 8*scaleFactor) trie, vals := randomTrie(4096)
var entries entrySlice var entries entrySlice
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, packEntry(kv)) entries = append(entries, kv)
} }
sort.Sort(entries) sort.Sort(entries)
for i := 0; i < 500; i++ { for i := 0; i < 500; i++ {
@ -286,10 +279,10 @@ func TestRangeProofWithNonExistentProof(t *testing.T) {
// - There exists a gap between the first element and the left edge proof // - There exists a gap between the first element and the left edge proof
// - There exists a gap between the last element and the right edge proof // - There exists a gap between the last element and the right edge proof
func TestRangeProofWithInvalidNonExistentProof(t *testing.T) { func TestRangeProofWithInvalidNonExistentProof(t *testing.T) {
trie, vals := randomTrie(t, 8*scaleFactor) trie, vals := randomTrie(4096)
var entries entrySlice var entries entrySlice
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, packEntry(kv)) entries = append(entries, kv)
} }
sort.Sort(entries) sort.Sort(entries)
@ -343,10 +336,10 @@ func TestRangeProofWithInvalidNonExistentProof(t *testing.T) {
// element. The first edge proof can be existent one or // element. The first edge proof can be existent one or
// non-existent one. // non-existent one.
func TestOneElementRangeProof(t *testing.T) { func TestOneElementRangeProof(t *testing.T) {
trie, vals := randomTrie(t, 8*scaleFactor) trie, vals := randomTrie(4096)
var entries entrySlice var entries entrySlice
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, packEntry(kv)) entries = append(entries, kv)
} }
sort.Sort(entries) sort.Sort(entries)
@ -408,14 +401,9 @@ func TestOneElementRangeProof(t *testing.T) {
} }
// Test the mini trie with only a single element. // Test the mini trie with only a single element.
t.Run("single element", func(t *testing.T) { tinyTrie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
edb := rawdb.NewMemoryDatabase() entry := &kv{randBytes(32), randBytes(20), false}
db := geth_trie.NewDatabase(edb) tinyTrie.Update(entry.k, entry.v)
orig := geth_trie.NewEmpty(db)
entry := &entry{randBytes(32), packValue(mrand.Int63())}
orig.Update(entry.k, entry.v)
root := commitTrie(t, db, orig)
tinyTrie := indexTrie(t, edb, root)
first = common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes() first = common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes()
last = entry.k last = entry.k
@ -430,16 +418,15 @@ func TestOneElementRangeProof(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
})
} }
// TestAllElementsProof tests the range proof with all elements. // TestAllElementsProof tests the range proof with all elements.
// The edge proofs can be nil. // The edge proofs can be nil.
func TestAllElementsProof(t *testing.T) { func TestAllElementsProof(t *testing.T) {
trie, vals := randomTrie(t, 8*scaleFactor) trie, vals := randomTrie(4096)
var entries entrySlice var entries entrySlice
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, packEntry(kv)) entries = append(entries, kv)
} }
sort.Sort(entries) sort.Sort(entries)
@ -483,33 +470,20 @@ func TestAllElementsProof(t *testing.T) {
} }
} }
func positionCases(size int) []int {
var cases []int
for _, pos := range []int{0, 1, 50, 100, 1000, 2000, size - 1} {
if pos >= size {
break
}
cases = append(cases, pos)
}
return cases
}
// TestSingleSideRangeProof tests the range starts from zero. // TestSingleSideRangeProof tests the range starts from zero.
func TestSingleSideRangeProof(t *testing.T) { func TestSingleSideRangeProof(t *testing.T) {
edb := rawdb.NewMemoryDatabase() for i := 0; i < 64; i++ {
db := geth_trie.NewDatabase(edb) trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
orig := geth_trie.NewEmpty(db)
var entries entrySlice var entries entrySlice
for i := 0; i < 8*scaleFactor; i++ { for i := 0; i < 4096; i++ {
value := &entry{randBytes(32), packValue(mrand.Int63())} value := &kv{randBytes(32), randBytes(20), false}
orig.Update(value.k, value.v) trie.Update(value.k, value.v)
entries = append(entries, value) entries = append(entries, value)
} }
root := commitTrie(t, db, orig)
trie := indexTrie(t, edb, root)
sort.Sort(entries) sort.Sort(entries)
for _, pos := range positionCases(len(entries)) { var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1}
for _, pos := range cases {
proof := memorydb.New() proof := memorydb.New()
if err := trie.Prove(common.Hash{}.Bytes(), 0, proof); err != nil { if err := trie.Prove(common.Hash{}.Bytes(), 0, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err) t.Fatalf("Failed to prove the first node %v", err)
@ -529,23 +503,22 @@ func TestSingleSideRangeProof(t *testing.T) {
} }
} }
} }
}
// TestReverseSingleSideRangeProof tests the range ends with 0xffff...fff. // TestReverseSingleSideRangeProof tests the range ends with 0xffff...fff.
func TestReverseSingleSideRangeProof(t *testing.T) { func TestReverseSingleSideRangeProof(t *testing.T) {
edb := rawdb.NewMemoryDatabase() for i := 0; i < 64; i++ {
db := geth_trie.NewDatabase(edb) trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
orig := geth_trie.NewEmpty(db)
var entries entrySlice var entries entrySlice
for i := 0; i < 8*scaleFactor; i++ { for i := 0; i < 4096; i++ {
value := &entry{randBytes(32), packValue(mrand.Int63())} value := &kv{randBytes(32), randBytes(20), false}
orig.Update(value.k, value.v) trie.Update(value.k, value.v)
entries = append(entries, value) entries = append(entries, value)
} }
root := commitTrie(t, db, orig)
trie := indexTrie(t, edb, root)
sort.Sort(entries) sort.Sort(entries)
for _, pos := range positionCases(len(entries)) { var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1}
for _, pos := range cases {
proof := memorydb.New() proof := memorydb.New()
if err := trie.Prove(entries[pos].k, 0, proof); err != nil { if err := trie.Prove(entries[pos].k, 0, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err) t.Fatalf("Failed to prove the first node %v", err)
@ -566,14 +539,15 @@ func TestReverseSingleSideRangeProof(t *testing.T) {
} }
} }
} }
}
// TestBadRangeProof tests a few cases which the proof is wrong. // TestBadRangeProof tests a few cases which the proof is wrong.
// The prover is expected to detect the error. // The prover is expected to detect the error.
func TestBadRangeProof(t *testing.T) { func TestBadRangeProof(t *testing.T) {
trie, vals := randomTrie(t, 8*scaleFactor) trie, vals := randomTrie(4096)
var entries entrySlice var entries entrySlice
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, packEntry(kv)) entries = append(entries, kv)
} }
sort.Sort(entries) sort.Sort(entries)
@ -641,17 +615,13 @@ func TestBadRangeProof(t *testing.T) {
// TestGappedRangeProof focuses on the small trie with embedded nodes. // TestGappedRangeProof focuses on the small trie with embedded nodes.
// If the gapped node is embedded in the trie, it should be detected too. // If the gapped node is embedded in the trie, it should be detected too.
func TestGappedRangeProof(t *testing.T) { func TestGappedRangeProof(t *testing.T) {
edb := rawdb.NewMemoryDatabase() trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
db := geth_trie.NewDatabase(edb) var entries []*kv // Sorted entries
orig := geth_trie.NewEmpty(db)
var entries entrySlice
for i := byte(0); i < 10; i++ { for i := byte(0); i < 10; i++ {
value := &entry{common.LeftPadBytes([]byte{i}, 32), packValue(int64(i))} value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
orig.Update(value.k, value.v) trie.Update(value.k, value.v)
entries = append(entries, value) entries = append(entries, value)
} }
root := commitTrie(t, db, orig)
trie := indexTrie(t, edb, root)
first, last := 2, 8 first, last := 2, 8
proof := memorydb.New() proof := memorydb.New()
if err := trie.Prove(entries[first].k, 0, proof); err != nil { if err := trie.Prove(entries[first].k, 0, proof); err != nil {
@ -677,10 +647,10 @@ func TestGappedRangeProof(t *testing.T) {
// TestSameSideProofs tests the element is not in the range covered by proofs // TestSameSideProofs tests the element is not in the range covered by proofs
func TestSameSideProofs(t *testing.T) { func TestSameSideProofs(t *testing.T) {
trie, vals := randomTrie(t, 8*scaleFactor) trie, vals := randomTrie(4096)
var entries entrySlice var entries entrySlice
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, packEntry(kv)) entries = append(entries, kv)
} }
sort.Sort(entries) sort.Sort(entries)
@ -719,17 +689,13 @@ func TestSameSideProofs(t *testing.T) {
} }
func TestHasRightElement(t *testing.T) { func TestHasRightElement(t *testing.T) {
edb := rawdb.NewMemoryDatabase() trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
db := geth_trie.NewDatabase(edb)
orig := geth_trie.NewEmpty(db)
var entries entrySlice var entries entrySlice
for i := 0; i < 8*scaleFactor; i++ { for i := 0; i < 4096; i++ {
value := &entry{randBytes(32), packValue(int64(i))} value := &kv{randBytes(32), randBytes(20), false}
orig.Update(value.k, value.v) trie.Update(value.k, value.v)
entries = append(entries, value) entries = append(entries, value)
} }
root := commitTrie(t, db, orig)
trie := indexTrie(t, edb, root)
sort.Sort(entries) sort.Sort(entries)
var cases = []struct { var cases = []struct {
@ -797,10 +763,10 @@ func TestHasRightElement(t *testing.T) {
// TestEmptyRangeProof tests the range proof with "no" element. // TestEmptyRangeProof tests the range proof with "no" element.
// The first edge proof must be a non-existent proof. // The first edge proof must be a non-existent proof.
func TestEmptyRangeProof(t *testing.T) { func TestEmptyRangeProof(t *testing.T) {
trie, vals := randomTrie(t, 8*scaleFactor) trie, vals := randomTrie(4096)
var entries entrySlice var entries entrySlice
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, packEntry(kv)) entries = append(entries, kv)
} }
sort.Sort(entries) sort.Sort(entries)
@ -827,14 +793,49 @@ func TestEmptyRangeProof(t *testing.T) {
} }
} }
// TestBloatedProof tests a malicious proof, where the proof is more or less the
// whole trie. Previously we didn't accept such packets, but the new APIs do, so
// lets leave this test as a bit weird, but present.
func TestBloatedProof(t *testing.T) {
// Use a small trie
trie, kvs := nonRandomTrie(100)
var entries entrySlice
for _, kv := range kvs {
entries = append(entries, kv)
}
sort.Sort(entries)
var keys [][]byte
var vals [][]byte
proof := memorydb.New()
// In the 'malicious' case, we add proofs for every single item
// (but only one key/value pair used as leaf)
for i, entry := range entries {
trie.Prove(entry.k, 0, proof)
if i == 50 {
keys = append(keys, entry.k)
vals = append(vals, entry.v)
}
}
// For reference, we use the same function, but _only_ prove the first
// and last element
want := memorydb.New()
trie.Prove(keys[0], 0, want)
trie.Prove(keys[len(keys)-1], 0, want)
if _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof); err != nil {
t.Fatalf("expected bloated proof to succeed, got %v", err)
}
}
// TestEmptyValueRangeProof tests normal range proof with both edge proofs // TestEmptyValueRangeProof tests normal range proof with both edge proofs
// as the existent proof, but with an extra empty value included, which is a // as the existent proof, but with an extra empty value included, which is a
// noop technically, but practically should be rejected. // noop technically, but practically should be rejected.
func TestEmptyValueRangeProof(t *testing.T) { func TestEmptyValueRangeProof(t *testing.T) {
trie, values := randomTrie(t, scaleFactor) trie, values := randomTrie(512)
var entries entrySlice var entries entrySlice
for _, kv := range values { for _, kv := range values {
entries = append(entries, packEntry(kv)) entries = append(entries, kv)
} }
sort.Sort(entries) sort.Sort(entries)
@ -847,8 +848,8 @@ func TestEmptyValueRangeProof(t *testing.T) {
break break
} }
} }
noop := &entry{key, []byte{}} noop := &kv{key, []byte{}, false}
entries = append(append(append(entrySlice{}, entries[:mid]...), noop), entries[mid:]...) entries = append(append(append([]*kv{}, entries[:mid]...), noop), entries[mid:]...)
start, end := 1, len(entries)-1 start, end := 1, len(entries)-1
@ -875,10 +876,10 @@ func TestEmptyValueRangeProof(t *testing.T) {
// but with an extra empty value included, which is a noop technically, but // but with an extra empty value included, which is a noop technically, but
// practically should be rejected. // practically should be rejected.
func TestAllElementsEmptyValueRangeProof(t *testing.T) { func TestAllElementsEmptyValueRangeProof(t *testing.T) {
trie, values := randomTrie(t, scaleFactor) trie, values := randomTrie(512)
var entries entrySlice var entries entrySlice
for _, kv := range values { for _, kv := range values {
entries = append(entries, packEntry(kv)) entries = append(entries, kv)
} }
sort.Sort(entries) sort.Sort(entries)
@ -891,8 +892,8 @@ func TestAllElementsEmptyValueRangeProof(t *testing.T) {
break break
} }
} }
noop := &entry{key, nil} noop := &kv{key, []byte{}, false}
entries = append(append(append(entrySlice{}, entries[:mid]...), noop), entries[mid:]...) entries = append(append(append([]*kv{}, entries[:mid]...), noop), entries[mid:]...)
var keys [][]byte var keys [][]byte
var vals [][]byte var vals [][]byte
@ -938,7 +939,7 @@ func decreaseKey(key []byte) []byte {
} }
func BenchmarkProve(b *testing.B) { func BenchmarkProve(b *testing.B) {
trie, vals := randomTrie(b, 100) trie, vals := randomTrie(100)
var keys []string var keys []string
for k := range vals { for k := range vals {
keys = append(keys, k) keys = append(keys, k)
@ -955,7 +956,7 @@ func BenchmarkProve(b *testing.B) {
} }
func BenchmarkVerifyProof(b *testing.B) { func BenchmarkVerifyProof(b *testing.B) {
trie, vals := randomTrie(b, 100) trie, vals := randomTrie(100)
root := trie.Hash() root := trie.Hash()
var keys []string var keys []string
var proofs []*memorydb.Database var proofs []*memorydb.Database
@ -981,10 +982,10 @@ func BenchmarkVerifyRangeProof1000(b *testing.B) { benchmarkVerifyRangeProof(b,
func BenchmarkVerifyRangeProof5000(b *testing.B) { benchmarkVerifyRangeProof(b, 5000) } func BenchmarkVerifyRangeProof5000(b *testing.B) { benchmarkVerifyRangeProof(b, 5000) }
func benchmarkVerifyRangeProof(b *testing.B, size int) { func benchmarkVerifyRangeProof(b *testing.B, size int) {
trie, vals := randomTrie(b, 8192) trie, vals := randomTrie(8192)
var entries entrySlice var entries entrySlice
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, packEntry(kv)) entries = append(entries, kv)
} }
sort.Sort(entries) sort.Sort(entries)
@ -1018,10 +1019,10 @@ func BenchmarkVerifyRangeNoProof500(b *testing.B) { benchmarkVerifyRangeNoProof
func BenchmarkVerifyRangeNoProof1000(b *testing.B) { benchmarkVerifyRangeNoProof(b, 1000) } func BenchmarkVerifyRangeNoProof1000(b *testing.B) { benchmarkVerifyRangeNoProof(b, 1000) }
func benchmarkVerifyRangeNoProof(b *testing.B, size int) { func benchmarkVerifyRangeNoProof(b *testing.B, size int) {
trie, vals := randomTrie(b, size) trie, vals := randomTrie(size)
var entries entrySlice var entries entrySlice
for _, kv := range vals { for _, kv := range vals {
entries = append(entries, packEntry(kv)) entries = append(entries, kv)
} }
sort.Sort(entries) sort.Sort(entries)
@ -1040,24 +1041,56 @@ func benchmarkVerifyRangeNoProof(b *testing.B, size int) {
} }
} }
func randomTrie(n int) (*Trie, map[string]*kv) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
vals := make(map[string]*kv)
for i := byte(0); i < 100; i++ {
value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
value2 := &kv{common.LeftPadBytes([]byte{i + 10}, 32), []byte{i}, false}
trie.Update(value.k, value.v)
trie.Update(value2.k, value2.v)
vals[string(value.k)] = value
vals[string(value2.k)] = value2
}
for i := 0; i < n; i++ {
value := &kv{randBytes(32), randBytes(20), false}
trie.Update(value.k, value.v)
vals[string(value.k)] = value
}
return trie, vals
}
func nonRandomTrie(n int) (*Trie, map[string]*kv) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
vals := make(map[string]*kv)
max := uint64(0xffffffffffffffff)
for i := uint64(0); i < uint64(n); i++ {
value := make([]byte, 32)
key := make([]byte, 32)
binary.LittleEndian.PutUint64(key, i)
binary.LittleEndian.PutUint64(value, i-max)
//value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
elem := &kv{key, value, false}
trie.Update(elem.k, elem.v)
vals[string(elem.k)] = elem
}
return trie, vals
}
func TestRangeProofKeysWithSharedPrefix(t *testing.T) { func TestRangeProofKeysWithSharedPrefix(t *testing.T) {
keys := [][]byte{ keys := [][]byte{
common.Hex2Bytes("aa10000000000000000000000000000000000000000000000000000000000000"), common.Hex2Bytes("aa10000000000000000000000000000000000000000000000000000000000000"),
common.Hex2Bytes("aa20000000000000000000000000000000000000000000000000000000000000"), common.Hex2Bytes("aa20000000000000000000000000000000000000000000000000000000000000"),
} }
vals := [][]byte{ vals := [][]byte{
packValue(2), common.Hex2Bytes("02"),
packValue(3), common.Hex2Bytes("03"),
} }
edb := rawdb.NewMemoryDatabase() trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
db := geth_trie.NewDatabase(edb)
orig := geth_trie.NewEmpty(db)
for i, key := range keys { for i, key := range keys {
orig.Update(key, vals[i]) trie.Update(key, vals[i])
} }
root := commitTrie(t, db, orig) root := trie.Hash()
trie := indexTrie(t, edb, root)
proof := memorydb.New() proof := memorydb.New()
start := common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000") start := common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000")
end := common.Hex2Bytes("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") end := common.Hex2Bytes("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")

View File

@ -17,82 +17,88 @@
package trie package trie
import ( import (
"fmt"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/statediff/indexer/ipld" log "github.com/sirupsen/logrus"
) )
// StateTrie wraps a trie with key hashing. In a secure trie, all // StateTrie wraps a trie with key hashing. In a stateTrie trie, all
// access operations hash the key using keccak256. This prevents // access operations hash the key using keccak256. This prevents
// calling code from creating long chains of nodes that // calling code from creating long chains of nodes that
// increase the access time. // increase the access time.
// //
// Contrary to a regular trie, a StateTrie can only be created with // Contrary to a regular trie, a StateTrie can only be created with
// New and must have an attached database. // New and must have an attached database. The database also stores
// the preimage of each key if preimage recording is enabled.
// //
// StateTrie is not safe for concurrent use. // StateTrie is not safe for concurrent use.
type StateTrie struct { type StateTrie struct {
trie Trie trie Trie
preimages *preimageStore
hashKeyBuf [common.HashLength]byte hashKeyBuf [common.HashLength]byte
secKeyCache map[string][]byte
secKeyCacheOwner *StateTrie // Pointer to self, replace the key cache on mismatch
} }
// NewStateTrie creates a trie with an existing root node from a backing database // NewStateTrie creates a trie with an existing root node from a backing database.
// and optional intermediate in-memory node pool.
// //
// If root is the zero hash or the sha3 hash of an empty string, the // If root is the zero hash or the sha3 hash of an empty string, the
// trie is initially empty. Otherwise, New will panic if db is nil // trie is initially empty. Otherwise, New will panic if db is nil
// and returns MissingNodeError if the root node cannot be found. // and returns MissingNodeError if the root node cannot be found.
// func NewStateTrie(id *ID, db *Database, codec uint64) (*StateTrie, error) {
// Accessing the trie loads nodes from the database or node pool on demand.
// Loaded nodes are kept around until their 'cache generation' expires.
// A new cache generation is created by each call to Commit.
// cachelimit sets the number of past cache generations to keep.
//
// Retrieves IPLD blocks by CID encoded as "eth-state-trie"
func NewStateTrie(owner common.Hash, root common.Hash, db *Database) (*StateTrie, error) {
return newStateTrie(owner, root, db, ipld.MEthStateTrie)
}
// NewStorageTrie is identical to NewStateTrie, but retrieves IPLD blocks encoded
// as "eth-storage-trie"
func NewStorageTrie(owner common.Hash, root common.Hash, db *Database) (*StateTrie, error) {
return newStateTrie(owner, root, db, ipld.MEthStorageTrie)
}
func newStateTrie(owner common.Hash, root common.Hash, db *Database, codec uint64) (*StateTrie, error) {
if db == nil { if db == nil {
panic("NewStateTrie called without a database") panic("trie.NewStateTrie called without a database")
} }
trie, err := New(owner, root, db, codec) trie, err := New(id, db, codec)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &StateTrie{trie: *trie}, nil return &StateTrie{trie: *trie, preimages: db.preimages}, nil
}
// Get returns the value for key stored in the trie.
// The value bytes must not be modified by the caller.
func (t *StateTrie) Get(key []byte) []byte {
res, err := t.TryGet(key)
if err != nil {
log.Error("Unhandled trie error in StateTrie.Get", "err", err)
}
return res
} }
// TryGet returns the value for key stored in the trie. // TryGet returns the value for key stored in the trie.
// The value bytes must not be modified by the caller. // The value bytes must not be modified by the caller.
// If a node was not found in the database, a MissingNodeError is returned. // If the specified node is not in the trie, nil will be returned.
// If a trie node is not found in the database, a MissingNodeError is returned.
func (t *StateTrie) TryGet(key []byte) ([]byte, error) { func (t *StateTrie) TryGet(key []byte) ([]byte, error) {
return t.trie.TryGet(t.hashKey(key)) return t.trie.TryGet(t.hashKey(key))
} }
func (t *StateTrie) TryGetAccount(key []byte) (*types.StateAccount, error) { // TryGetAccount attempts to retrieve an account with provided account address.
var ret types.StateAccount // If the specified account is not in the trie, nil will be returned.
res, err := t.TryGet(key) // If a trie node is not found in the database, a MissingNodeError is returned.
if err != nil { func (t *StateTrie) TryGetAccount(address common.Address) (*types.StateAccount, error) {
// log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) res, err := t.trie.TryGet(t.hashKey(address.Bytes()))
panic(fmt.Sprintf("Unhandled trie error: %v", err)) if res == nil || err != nil {
return &ret, err return nil, err
} }
if res == nil { ret := new(types.StateAccount)
return nil, nil err = rlp.DecodeBytes(res, ret)
return ret, err
} }
err = rlp.DecodeBytes(res, &ret)
return &ret, err // TryGetAccountByHash does the same thing as TryGetAccount, however
// it expects an account hash that is the hash of address. This constitutes an
// abstraction leak, since the client code needs to know the key format.
func (t *StateTrie) TryGetAccountByHash(addrHash common.Hash) (*types.StateAccount, error) {
res, err := t.trie.TryGet(addrHash.Bytes())
if res == nil || err != nil {
return nil, err
}
ret := new(types.StateAccount)
err = rlp.DecodeBytes(res, ret)
return ret, err
} }
// TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not // TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not
@ -103,12 +109,124 @@ func (t *StateTrie) TryGetNode(path []byte) ([]byte, int, error) {
return t.trie.TryGetNode(path) return t.trie.TryGetNode(path)
} }
// Update associates key with value in the trie. Subsequent calls to
// Get will return value. If value has length zero, any existing value
// is deleted from the trie and calls to Get will return nil.
//
// The value bytes must not be modified by the caller while they are
// stored in the trie.
func (t *StateTrie) Update(key, value []byte) {
if err := t.TryUpdate(key, value); err != nil {
log.Error("Unhandled trie error in StateTrie.Update", "err", err)
}
}
// TryUpdate associates key with value in the trie. Subsequent calls to
// Get will return value. If value has length zero, any existing value
// is deleted from the trie and calls to Get will return nil.
//
// The value bytes must not be modified by the caller while they are
// stored in the trie.
//
// If a node is not found in the database, a MissingNodeError is returned.
func (t *StateTrie) TryUpdate(key, value []byte) error {
hk := t.hashKey(key)
err := t.trie.TryUpdate(hk, value)
if err != nil {
return err
}
t.getSecKeyCache()[string(hk)] = common.CopyBytes(key)
return nil
}
// TryUpdateAccount account will abstract the write of an account to the
// secure trie.
func (t *StateTrie) TryUpdateAccount(address common.Address, acc *types.StateAccount) error {
hk := t.hashKey(address.Bytes())
data, err := rlp.EncodeToBytes(acc)
if err != nil {
return err
}
if err := t.trie.TryUpdate(hk, data); err != nil {
return err
}
t.getSecKeyCache()[string(hk)] = address.Bytes()
return nil
}
// Delete removes any existing value for key from the trie.
func (t *StateTrie) Delete(key []byte) {
if err := t.TryDelete(key); err != nil {
log.Error("Unhandled trie error in StateTrie.Delete", "err", err)
}
}
// TryDelete removes any existing value for key from the trie.
// If the specified trie node is not in the trie, nothing will be changed.
// If a node is not found in the database, a MissingNodeError is returned.
func (t *StateTrie) TryDelete(key []byte) error {
hk := t.hashKey(key)
delete(t.getSecKeyCache(), string(hk))
return t.trie.TryDelete(hk)
}
// TryDeleteAccount abstracts an account deletion from the trie.
func (t *StateTrie) TryDeleteAccount(address common.Address) error {
hk := t.hashKey(address.Bytes())
delete(t.getSecKeyCache(), string(hk))
return t.trie.TryDelete(hk)
}
// GetKey returns the sha3 preimage of a hashed key that was
// previously used to store a value.
func (t *StateTrie) GetKey(shaKey []byte) []byte {
if key, ok := t.getSecKeyCache()[string(shaKey)]; ok {
return key
}
if t.preimages == nil {
return nil
}
return t.preimages.preimage(common.BytesToHash(shaKey))
}
// Commit collects all dirty nodes in the trie and replaces them with the
// corresponding node hash. All collected nodes (including dirty leaves if
// collectLeaf is true) will be encapsulated into a nodeset for return.
// The returned nodeset can be nil if the trie is clean (nothing to commit).
// All cached preimages will be also flushed if preimages recording is enabled.
// Once the trie is committed, it's not usable anymore. A new trie must
// be created with new root and updated trie database for following usage
func (t *StateTrie) Commit(collectLeaf bool) (common.Hash, *NodeSet) {
// Write all the pre-images to the actual disk database
if len(t.getSecKeyCache()) > 0 {
if t.preimages != nil {
preimages := make(map[common.Hash][]byte)
for hk, key := range t.secKeyCache {
preimages[common.BytesToHash([]byte(hk))] = key
}
t.preimages.insertPreimage(preimages)
}
t.secKeyCache = make(map[string][]byte)
}
// Commit the trie and return its modified nodeset.
return t.trie.Commit(collectLeaf)
}
// Hash returns the root hash of StateTrie. It does not write to the // Hash returns the root hash of StateTrie. It does not write to the
// database and can be used even if the trie doesn't have one. // database and can be used even if the trie doesn't have one.
func (t *StateTrie) Hash() common.Hash { func (t *StateTrie) Hash() common.Hash {
return t.trie.Hash() return t.trie.Hash()
} }
// Copy returns a copy of StateTrie.
func (t *StateTrie) Copy() *StateTrie {
return &StateTrie{
trie: *t.trie.Copy(),
preimages: t.preimages,
secKeyCache: t.secKeyCache,
}
}
// NodeIterator returns an iterator that returns nodes of the underlying trie. Iteration // NodeIterator returns an iterator that returns nodes of the underlying trie. Iteration
// starts at the key after the given start key. // starts at the key after the given start key.
func (t *StateTrie) NodeIterator(start []byte) NodeIterator { func (t *StateTrie) NodeIterator(start []byte) NodeIterator {
@ -126,3 +244,14 @@ func (t *StateTrie) hashKey(key []byte) []byte {
returnHasherToPool(h) returnHasherToPool(h)
return t.hashKeyBuf[:] return t.hashKeyBuf[:]
} }
// getSecKeyCache returns the current secure key cache, creating a new one if
// ownership changed (i.e. the current secure trie is a copy of another owning
// the actual cache).
func (t *StateTrie) getSecKeyCache() map[string][]byte {
if t != t.secKeyCacheOwner {
t.secKeyCacheOwner = t
t.secKeyCache = make(map[string][]byte)
}
return t.secKeyCache
}

View File

@ -0,0 +1,148 @@
// Copyright 2015 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package trie
import (
"bytes"
"fmt"
"runtime"
"sync"
"testing"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/crypto"
)
// makeTestStateTrie creates a large enough secure trie for testing.
func makeTestStateTrie() (*Database, *StateTrie, map[string][]byte) {
// Create an empty trie
triedb := NewDatabase(rawdb.NewMemoryDatabase())
trie, _ := NewStateTrie(TrieID(common.Hash{}), triedb, StateTrieCodec)
// Fill it with some arbitrary data
content := make(map[string][]byte)
for i := byte(0); i < 255; i++ {
// Map the same data under multiple keys
key, val := common.LeftPadBytes([]byte{1, i}, 32), []byte{i}
content[string(key)] = val
trie.Update(key, val)
key, val = common.LeftPadBytes([]byte{2, i}, 32), []byte{i}
content[string(key)] = val
trie.Update(key, val)
// Add some other data to inflate the trie
for j := byte(3); j < 13; j++ {
key, val = common.LeftPadBytes([]byte{j, i}, 32), []byte{j, i}
content[string(key)] = val
trie.Update(key, val)
}
}
root, nodes := trie.Commit(false)
if err := triedb.Update(NewWithNodeSet(nodes)); err != nil {
panic(fmt.Errorf("failed to commit db %v", err))
}
// Re-create the trie based on the new state
trie, _ = NewStateTrie(TrieID(root), triedb, StateTrieCodec)
return triedb, trie, content
}
func TestSecureDelete(t *testing.T) {
trie, err := NewStateTrie(TrieID(common.Hash{}), NewDatabase(rawdb.NewMemoryDatabase()), StateTrieCodec)
if err != nil {
t.Fatal(err)
}
vals := []struct{ k, v string }{
{"do", "verb"},
{"ether", "wookiedoo"},
{"horse", "stallion"},
{"shaman", "horse"},
{"doge", "coin"},
{"ether", ""},
{"dog", "puppy"},
{"shaman", ""},
}
for _, val := range vals {
if val.v != "" {
trie.Update([]byte(val.k), []byte(val.v))
} else {
trie.Delete([]byte(val.k))
}
}
hash := trie.Hash()
exp := common.HexToHash("29b235a58c3c25ab83010c327d5932bcf05324b7d6b1185e650798034783ca9d")
if hash != exp {
t.Errorf("expected %x got %x", exp, hash)
}
}
func TestSecureGetKey(t *testing.T) {
trie, err := NewStateTrie(TrieID(common.Hash{}), NewDatabase(rawdb.NewMemoryDatabase()), StateTrieCodec)
if err != nil {
t.Fatal(err)
}
trie.Update([]byte("foo"), []byte("bar"))
key := []byte("foo")
value := []byte("bar")
seckey := crypto.Keccak256(key)
if !bytes.Equal(trie.Get(key), value) {
t.Errorf("Get did not return bar")
}
if k := trie.GetKey(seckey); !bytes.Equal(k, key) {
t.Errorf("GetKey returned %q, want %q", k, key)
}
}
func TestStateTrieConcurrency(t *testing.T) {
// Create an initial trie and copy if for concurrent access
_, trie, _ := makeTestStateTrie()
threads := runtime.NumCPU()
tries := make([]*StateTrie, threads)
for i := 0; i < threads; i++ {
tries[i] = trie.Copy()
}
// Start a batch of goroutines interacting with the trie
pend := new(sync.WaitGroup)
pend.Add(threads)
for i := 0; i < threads; i++ {
go func(index int) {
defer pend.Done()
for j := byte(0); j < 255; j++ {
// Map the same data under multiple keys
key, val := common.LeftPadBytes([]byte{byte(index), 1, j}, 32), []byte{j}
tries[index].Update(key, val)
key, val = common.LeftPadBytes([]byte{byte(index), 2, j}, 32), []byte{j}
tries[index].Update(key, val)
// Add some other data to inflate the trie
for k := byte(3); k < 13; k++ {
key, val = common.LeftPadBytes([]byte{byte(index), k, j}, 32), []byte{k, j}
tries[index].Update(key, val)
}
}
tries[index].Commit(false)
}(i)
}
// Wait for all threads to finish
pend.Wait()
}

125
trie_by_cid/trie/tracer.go Normal file
View File

@ -0,0 +1,125 @@
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package trie
import "github.com/ethereum/go-ethereum/common"
// tracer tracks the changes of trie nodes. During the trie operations,
// some nodes can be deleted from the trie, while these deleted nodes
// won't be captured by trie.Hasher or trie.Committer. Thus, these deleted
// nodes won't be removed from the disk at all. Tracer is an auxiliary tool
// used to track all insert and delete operations of trie and capture all
// deleted nodes eventually.
//
// The changed nodes can be mainly divided into two categories: the leaf
// node and intermediate node. The former is inserted/deleted by callers
// while the latter is inserted/deleted in order to follow the rule of trie.
// This tool can track all of them no matter the node is embedded in its
// parent or not, but valueNode is never tracked.
//
// Besides, it's also used for recording the original value of the nodes
// when they are resolved from the disk. The pre-value of the nodes will
// be used to construct trie history in the future.
//
// Note tracer is not thread-safe, callers should be responsible for handling
// the concurrency issues by themselves.
type tracer struct {
inserts map[string]struct{}
deletes map[string]struct{}
accessList map[string][]byte
}
// newTracer initializes the tracer for capturing trie changes.
func newTracer() *tracer {
return &tracer{
inserts: make(map[string]struct{}),
deletes: make(map[string]struct{}),
accessList: make(map[string][]byte),
}
}
// onRead tracks the newly loaded trie node and caches the rlp-encoded
// blob internally. Don't change the value outside of function since
// it's not deep-copied.
func (t *tracer) onRead(path []byte, val []byte) {
t.accessList[string(path)] = val
}
// onInsert tracks the newly inserted trie node. If it's already
// in the deletion set (resurrected node), then just wipe it from
// the deletion set as it's "untouched".
func (t *tracer) onInsert(path []byte) {
if _, present := t.deletes[string(path)]; present {
delete(t.deletes, string(path))
return
}
t.inserts[string(path)] = struct{}{}
}
// onDelete tracks the newly deleted trie node. If it's already
// in the addition set, then just wipe it from the addition set
// as it's untouched.
func (t *tracer) onDelete(path []byte) {
if _, present := t.inserts[string(path)]; present {
delete(t.inserts, string(path))
return
}
t.deletes[string(path)] = struct{}{}
}
// reset clears the content tracked by tracer.
func (t *tracer) reset() {
t.inserts = make(map[string]struct{})
t.deletes = make(map[string]struct{})
t.accessList = make(map[string][]byte)
}
// copy returns a deep copied tracer instance.
func (t *tracer) copy() *tracer {
var (
inserts = make(map[string]struct{})
deletes = make(map[string]struct{})
accessList = make(map[string][]byte)
)
for path := range t.inserts {
inserts[path] = struct{}{}
}
for path := range t.deletes {
deletes[path] = struct{}{}
}
for path, blob := range t.accessList {
accessList[path] = common.CopyBytes(blob)
}
return &tracer{
inserts: inserts,
deletes: deletes,
accessList: accessList,
}
}
// markDeletions puts all tracked deletions into the provided nodeset.
func (t *tracer) markDeletions(set *NodeSet) {
for path := range t.deletes {
// It's possible a few deleted nodes were embedded
// in their parent before, the deletions can be no
// effect by deleting nothing, filter them out.
if _, ok := set.accessList[path]; !ok {
continue
}
set.markDeleted([]byte(path))
}
}

View File

@ -7,18 +7,15 @@ import (
"fmt" "fmt"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/core/types"
log "github.com/sirupsen/logrus"
util "github.com/cerc-io/ipld-eth-statedb/internal"
"github.com/ethereum/go-ethereum/statediff/indexer/ipld" "github.com/ethereum/go-ethereum/statediff/indexer/ipld"
) )
var ( var (
// emptyRoot is the known root hash of an empty trie. StateTrieCodec uint64 = ipld.MEthStateTrie
emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") StorageTrieCodec uint64 = ipld.MEthStorageTrie
// emptyState is the known hash of an empty state trie entry.
emptyState = crypto.Keccak256Hash(nil)
) )
// Trie is a Merkle Patricia Trie. Use New to create a trie that sits on // Trie is a Merkle Patricia Trie. Use New to create a trie that sits on
@ -37,29 +34,48 @@ type Trie struct {
// actually unhashed nodes. // actually unhashed nodes.
unhashed int unhashed int
// db is the handler trie can retrieve nodes from. It's // reader is the handler trie can retrieve nodes from.
// only for reading purpose and not available for writing. reader *trieReader
db *Database
// Multihash codec for key encoding // tracer is the tool to track the trie changes.
codec uint64 // It will be reset after each commit operation.
tracer *tracer
} }
// New creates a trie with an existing root node from db and an assigned // newFlag returns the cache flag value for a newly created node.
// owner for storage proximity. func (t *Trie) newFlag() nodeFlag {
// return nodeFlag{dirty: true}
// If root is the zero hash or the sha3 hash of an empty string, the }
// trie is initially empty and does not require a database. Otherwise,
// New will panic if db is nil and returns a MissingNodeError if root does // Copy returns a copy of Trie.
// not exist in the database. Accessing the trie loads nodes from db on demand. func (t *Trie) Copy() *Trie {
func New(owner common.Hash, root common.Hash, db *Database, codec uint64) (*Trie, error) { return &Trie{
root: t.root,
owner: t.owner,
unhashed: t.unhashed,
reader: t.reader,
tracer: t.tracer.copy(),
}
}
// New creates a trie instance with the provided trie id and the read-only
// database. The state specified by trie id must be available, otherwise
// an error will be returned. The trie root specified by trie id can be
// zero hash or the sha3 hash of an empty string, then trie is initially
// empty, otherwise, the root node must be present in database or returns
// a MissingNodeError if not.
func New(id *ID, db NodeReader, codec uint64) (*Trie, error) {
reader, err := newTrieReader(id.StateRoot, id.Owner, db, codec)
if err != nil {
return nil, err
}
trie := &Trie{ trie := &Trie{
owner: owner, owner: id.Owner,
db: db, reader: reader,
codec: codec, tracer: newTracer(),
} }
if root != (common.Hash{}) && root != emptyRoot { if id.Root != (common.Hash{}) && id.Root != types.EmptyRootHash {
rootnode, err := trie.resolveHash(root[:], nil) rootnode, err := trie.resolveAndTrack(id.Root[:], nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -70,7 +86,10 @@ func New(owner common.Hash, root common.Hash, db *Database, codec uint64) (*Trie
// NewEmpty is a shortcut to create empty tree. It's mostly used in tests. // NewEmpty is a shortcut to create empty tree. It's mostly used in tests.
func NewEmpty(db *Database) *Trie { func NewEmpty(db *Database) *Trie {
tr, _ := New(common.Hash{}, common.Hash{}, db, ipld.MEthStateTrie) tr, err := New(TrieID(common.Hash{}), db, StateTrieCodec)
if err != nil {
panic(err)
}
return tr return tr
} }
@ -80,6 +99,16 @@ func (t *Trie) NodeIterator(start []byte) NodeIterator {
return newNodeIterator(t, start) return newNodeIterator(t, start)
} }
// Get returns the value for key stored in the trie.
// The value bytes must not be modified by the caller.
func (t *Trie) Get(key []byte) []byte {
res, err := t.TryGet(key)
if err != nil {
log.Error("Unhandled trie error in Trie.Get", "err", err)
}
return res
}
// TryGet returns the value for key stored in the trie. // TryGet returns the value for key stored in the trie.
// The value bytes must not be modified by the caller. // The value bytes must not be modified by the caller.
// If a node was not found in the database, a MissingNodeError is returned. // If a node was not found in the database, a MissingNodeError is returned.
@ -116,7 +145,7 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode
} }
return value, n, didResolve, err return value, n, didResolve, err
case hashNode: case hashNode:
child, err := t.resolveHash(n, key[:pos]) child, err := t.resolveAndTrack(n, key[:pos])
if err != nil { if err != nil {
return nil, n, true, err return nil, n, true, err
} }
@ -130,7 +159,7 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode
// TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not // TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not
// possible to use keybyte-encoding as the path might contain odd nibbles. // possible to use keybyte-encoding as the path might contain odd nibbles.
func (t *Trie) TryGetNode(path []byte) ([]byte, int, error) { func (t *Trie) TryGetNode(path []byte) ([]byte, int, error) {
item, newroot, resolved, err := t.tryGetNode(t.root, CompactToHex(path), 0) item, newroot, resolved, err := t.tryGetNode(t.root, compactToHex(path), 0)
if err != nil { if err != nil {
return nil, resolved, err return nil, resolved, err
} }
@ -162,11 +191,7 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new
if hash == nil { if hash == nil {
return nil, origNode, 0, errors.New("non-consensus node") return nil, origNode, 0, errors.New("non-consensus node")
} }
cid, err := util.Keccak256ToCid(t.codec, hash) blob, err := t.reader.nodeBlob(path, common.BytesToHash(hash))
if err != nil {
return nil, origNode, 0, err
}
blob, err := t.db.Node(cid)
return blob, origNode, 1, err return blob, origNode, 1, err
} }
// Path still needs to be traversed, descend into children // Path still needs to be traversed, descend into children
@ -196,7 +221,7 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new
return item, n, resolved, err return item, n, resolved, err
case hashNode: case hashNode:
child, err := t.resolveHash(n, path[:pos]) child, err := t.resolveAndTrack(n, path[:pos])
if err != nil { if err != nil {
return nil, n, 1, err return nil, n, 1, err
} }
@ -208,42 +233,309 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new
} }
} }
// resolveHash loads node from the underlying database with the provided // Update associates key with value in the trie. Subsequent calls to
// node hash and path prefix. // Get will return value. If value has length zero, any existing value
func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) { // is deleted from the trie and calls to Get will return nil.
cid, err := util.Keccak256ToCid(t.codec, n) //
if err != nil { // The value bytes must not be modified by the caller while they are
return nil, err // stored in the trie.
func (t *Trie) Update(key, value []byte) {
if err := t.TryUpdate(key, value); err != nil {
log.Error("Unhandled trie error in Trie.Update", "err", err)
} }
enc, err := t.db.Node(cid)
if err != nil {
return nil, &MissingNodeError{Owner: t.owner, NodeHash: n, Path: prefix, err: err}
}
node, err := decodeNodeUnsafe(n, enc)
if err != nil {
return nil, err
}
if node != nil {
return node, nil
}
return nil, &MissingNodeError{Owner: t.owner, NodeHash: n, Path: prefix}
} }
// resolveHash loads rlp-encoded node blob from the underlying database // TryUpdate associates key with value in the trie. Subsequent calls to
// with the provided node hash and path prefix. // Get will return value. If value has length zero, any existing value
func (t *Trie) resolveBlob(n hashNode, prefix []byte) ([]byte, error) { // is deleted from the trie and calls to Get will return nil.
cid, err := util.Keccak256ToCid(t.codec, n) //
// The value bytes must not be modified by the caller while they are
// stored in the trie.
//
// If a node was not found in the database, a MissingNodeError is returned.
func (t *Trie) TryUpdate(key, value []byte) error {
return t.tryUpdate(key, value)
}
// tryUpdate expects an RLP-encoded value and performs the core function
// for TryUpdate and TryUpdateAccount.
func (t *Trie) tryUpdate(key, value []byte) error {
t.unhashed++
k := keybytesToHex(key)
if len(value) != 0 {
_, n, err := t.insert(t.root, nil, k, valueNode(value))
if err != nil {
return err
}
t.root = n
} else {
_, n, err := t.delete(t.root, nil, k)
if err != nil {
return err
}
t.root = n
}
return nil
}
func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error) {
if len(key) == 0 {
if v, ok := n.(valueNode); ok {
return !bytes.Equal(v, value.(valueNode)), value, nil
}
return true, value, nil
}
switch n := n.(type) {
case *shortNode:
matchlen := prefixLen(key, n.Key)
// If the whole key matches, keep this short node as is
// and only update the value.
if matchlen == len(n.Key) {
dirty, nn, err := t.insert(n.Val, append(prefix, key[:matchlen]...), key[matchlen:], value)
if !dirty || err != nil {
return false, n, err
}
return true, &shortNode{n.Key, nn, t.newFlag()}, nil
}
// Otherwise branch out at the index where they differ.
branch := &fullNode{flags: t.newFlag()}
var err error
_, branch.Children[n.Key[matchlen]], err = t.insert(nil, append(prefix, n.Key[:matchlen+1]...), n.Key[matchlen+1:], n.Val)
if err != nil {
return false, nil, err
}
_, branch.Children[key[matchlen]], err = t.insert(nil, append(prefix, key[:matchlen+1]...), key[matchlen+1:], value)
if err != nil {
return false, nil, err
}
// Replace this shortNode with the branch if it occurs at index 0.
if matchlen == 0 {
return true, branch, nil
}
// New branch node is created as a child of the original short node.
// Track the newly inserted node in the tracer. The node identifier
// passed is the path from the root node.
t.tracer.onInsert(append(prefix, key[:matchlen]...))
// Replace it with a short node leading up to the branch.
return true, &shortNode{key[:matchlen], branch, t.newFlag()}, nil
case *fullNode:
dirty, nn, err := t.insert(n.Children[key[0]], append(prefix, key[0]), key[1:], value)
if !dirty || err != nil {
return false, n, err
}
n = n.copy()
n.flags = t.newFlag()
n.Children[key[0]] = nn
return true, n, nil
case nil:
// New short node is created and track it in the tracer. The node identifier
// passed is the path from the root node. Note the valueNode won't be tracked
// since it's always embedded in its parent.
t.tracer.onInsert(prefix)
return true, &shortNode{key, value, t.newFlag()}, nil
case hashNode:
// We've hit a part of the trie that isn't loaded yet. Load
// the node and insert into it. This leaves all child nodes on
// the path to the value in the trie.
rn, err := t.resolveAndTrack(n, prefix)
if err != nil {
return false, nil, err
}
dirty, nn, err := t.insert(rn, prefix, key, value)
if !dirty || err != nil {
return false, rn, err
}
return true, nn, nil
default:
panic(fmt.Sprintf("%T: invalid node: %v", n, n))
}
}
// Delete removes any existing value for key from the trie.
func (t *Trie) Delete(key []byte) {
if err := t.TryDelete(key); err != nil {
log.Error("Unhandled trie error in Trie.Delete", "err", err)
}
}
// TryDelete removes any existing value for key from the trie.
// If a node was not found in the database, a MissingNodeError is returned.
func (t *Trie) TryDelete(key []byte) error {
t.unhashed++
k := keybytesToHex(key)
_, n, err := t.delete(t.root, nil, k)
if err != nil {
return err
}
t.root = n
return nil
}
// delete returns the new root of the trie with key deleted.
// It reduces the trie to minimal form by simplifying
// nodes on the way up after deleting recursively.
func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
switch n := n.(type) {
case *shortNode:
matchlen := prefixLen(key, n.Key)
if matchlen < len(n.Key) {
return false, n, nil // don't replace n on mismatch
}
if matchlen == len(key) {
// The matched short node is deleted entirely and track
// it in the deletion set. The same the valueNode doesn't
// need to be tracked at all since it's always embedded.
t.tracer.onDelete(prefix)
return true, nil, nil // remove n entirely for whole matches
}
// The key is longer than n.Key. Remove the remaining suffix
// from the subtrie. Child can never be nil here since the
// subtrie must contain at least two other values with keys
// longer than n.Key.
dirty, child, err := t.delete(n.Val, append(prefix, key[:len(n.Key)]...), key[len(n.Key):])
if !dirty || err != nil {
return false, n, err
}
switch child := child.(type) {
case *shortNode:
// The child shortNode is merged into its parent, track
// is deleted as well.
t.tracer.onDelete(append(prefix, n.Key...))
// Deleting from the subtrie reduced it to another
// short node. Merge the nodes to avoid creating a
// shortNode{..., shortNode{...}}. Use concat (which
// always creates a new slice) instead of append to
// avoid modifying n.Key since it might be shared with
// other nodes.
return true, &shortNode{concat(n.Key, child.Key...), child.Val, t.newFlag()}, nil
default:
return true, &shortNode{n.Key, child, t.newFlag()}, nil
}
case *fullNode:
dirty, nn, err := t.delete(n.Children[key[0]], append(prefix, key[0]), key[1:])
if !dirty || err != nil {
return false, n, err
}
n = n.copy()
n.flags = t.newFlag()
n.Children[key[0]] = nn
// Because n is a full node, it must've contained at least two children
// before the delete operation. If the new child value is non-nil, n still
// has at least two children after the deletion, and cannot be reduced to
// a short node.
if nn != nil {
return true, n, nil
}
// Reduction:
// Check how many non-nil entries are left after deleting and
// reduce the full node to a short node if only one entry is
// left. Since n must've contained at least two children
// before deletion (otherwise it would not be a full node) n
// can never be reduced to nil.
//
// When the loop is done, pos contains the index of the single
// value that is left in n or -2 if n contains at least two
// values.
pos := -1
for i, cld := range &n.Children {
if cld != nil {
if pos == -1 {
pos = i
} else {
pos = -2
break
}
}
}
if pos >= 0 {
if pos != 16 {
// If the remaining entry is a short node, it replaces
// n and its key gets the missing nibble tacked to the
// front. This avoids creating an invalid
// shortNode{..., shortNode{...}}. Since the entry
// might not be loaded yet, resolve it just for this
// check.
cnode, err := t.resolve(n.Children[pos], append(prefix, byte(pos)))
if err != nil {
return false, nil, err
}
if cnode, ok := cnode.(*shortNode); ok {
// Replace the entire full node with the short node.
// Mark the original short node as deleted since the
// value is embedded into the parent now.
t.tracer.onDelete(append(prefix, byte(pos)))
k := append([]byte{byte(pos)}, cnode.Key...)
return true, &shortNode{k, cnode.Val, t.newFlag()}, nil
}
}
// Otherwise, n is replaced by a one-nibble short node
// containing the child.
return true, &shortNode{[]byte{byte(pos)}, n.Children[pos], t.newFlag()}, nil
}
// n still contains at least two values and cannot be reduced.
return true, n, nil
case valueNode:
return true, nil, nil
case nil:
return false, nil, nil
case hashNode:
// We've hit a part of the trie that isn't loaded yet. Load
// the node and delete from it. This leaves all child nodes on
// the path to the value in the trie.
rn, err := t.resolveAndTrack(n, prefix)
if err != nil {
return false, nil, err
}
dirty, nn, err := t.delete(rn, prefix, key)
if !dirty || err != nil {
return false, rn, err
}
return true, nn, nil
default:
panic(fmt.Sprintf("%T: invalid node: %v (%v)", n, n, key))
}
}
func concat(s1 []byte, s2 ...byte) []byte {
r := make([]byte, len(s1)+len(s2))
copy(r, s1)
copy(r[len(s1):], s2)
return r
}
func (t *Trie) resolve(n node, prefix []byte) (node, error) {
if n, ok := n.(hashNode); ok {
return t.resolveAndTrack(n, prefix)
}
return n, nil
}
// resolveAndTrack loads node from the underlying store with the given node hash
// and path prefix and also tracks the loaded node blob in tracer treated as the
// node's original value. The rlp-encoded blob is preferred to be loaded from
// database because it's easy to decode node while complex to encode node to blob.
func (t *Trie) resolveAndTrack(n hashNode, prefix []byte) (node, error) {
blob, err := t.reader.nodeBlob(prefix, common.BytesToHash(n))
if err != nil { if err != nil {
return nil, err return nil, err
} }
blob, err := t.db.Node(cid) t.tracer.onRead(prefix, blob)
if err != nil { return mustDecodeNode(n, blob), nil
return nil, err
}
if len(blob) != 0 {
return blob, nil
}
return nil, &MissingNodeError{Owner: t.owner, NodeHash: n, Path: prefix}
} }
// Hash returns the root hash of the trie. It does not write to the // Hash returns the root hash of the trie. It does not write to the
@ -254,15 +546,59 @@ func (t *Trie) Hash() common.Hash {
return common.BytesToHash(hash.(hashNode)) return common.BytesToHash(hash.(hashNode))
} }
// Commit collects all dirty nodes in the trie and replaces them with the
// corresponding node hash. All collected nodes (including dirty leaves if
// collectLeaf is true) will be encapsulated into a nodeset for return.
// The returned nodeset can be nil if the trie is clean (nothing to commit).
// Once the trie is committed, it's not usable anymore. A new trie must
// be created with new root and updated trie database for following usage
func (t *Trie) Commit(collectLeaf bool) (common.Hash, *NodeSet) {
defer t.tracer.reset()
nodes := NewNodeSet(t.owner, t.tracer.accessList)
t.tracer.markDeletions(nodes)
// Trie is empty and can be classified into two types of situations:
// - The trie was empty and no update happens
// - The trie was non-empty and all nodes are dropped
if t.root == nil {
return types.EmptyRootHash, nodes
}
// Derive the hash for all dirty nodes first. We hold the assumption
// in the following procedure that all nodes are hashed.
rootHash := t.Hash()
// Do a quick check if we really need to commit. This can happen e.g.
// if we load a trie for reading storage values, but don't write to it.
if hashedNode, dirty := t.root.cache(); !dirty {
// Replace the root node with the origin hash in order to
// ensure all resolved nodes are dropped after the commit.
t.root = hashedNode
return rootHash, nil
}
t.root = newCommitter(nodes, collectLeaf).Commit(t.root)
return rootHash, nodes
}
// hashRoot calculates the root hash of the given trie // hashRoot calculates the root hash of the given trie
func (t *Trie) hashRoot() (node, node) { func (t *Trie) hashRoot() (node, node) {
if t.root == nil { if t.root == nil {
return hashNode(emptyRoot.Bytes()), nil return hashNode(types.EmptyRootHash.Bytes()), nil
} }
// If the number of changes is below 100, we let one thread handle it // If the number of changes is below 100, we let one thread handle it
h := newHasher(t.unhashed >= 100) h := newHasher(t.unhashed >= 100)
defer returnHasherToPool(h) defer func() {
hashed, cached := h.hash(t.root, true) returnHasherToPool(h)
t.unhashed = 0 t.unhashed = 0
}()
hashed, cached := h.hash(t.root, true)
return hashed, cached return hashed, cached
} }
// Reset drops the referenced root node and cleans all internal state.
func (t *Trie) Reset() {
t.root = nil
t.owner = common.Hash{}
t.unhashed = 0
t.tracer.reset()
}

View File

@ -0,0 +1,55 @@
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>
package trie
import "github.com/ethereum/go-ethereum/common"
// ID is the identifier for uniquely identifying a trie.
type ID struct {
StateRoot common.Hash // The root of the corresponding state(block.root)
Owner common.Hash // The contract address hash which the trie belongs to
Root common.Hash // The root hash of trie
}
// StateTrieID constructs an identifier for state trie with the provided state root.
func StateTrieID(root common.Hash) *ID {
return &ID{
StateRoot: root,
Owner: common.Hash{},
Root: root,
}
}
// StorageTrieID constructs an identifier for storage trie which belongs to a certain
// state and contract specified by the stateRoot and owner.
func StorageTrieID(stateRoot common.Hash, owner common.Hash, root common.Hash) *ID {
return &ID{
StateRoot: stateRoot,
Owner: owner,
Root: root,
}
}
// TrieID constructs an identifier for a standard trie(not a second-layer trie)
// with provided root. It's mostly used in tests and some other tries like CHT trie.
func TrieID(root common.Hash) *ID {
return &ID{
StateRoot: root,
Owner: common.Hash{},
Root: root,
}
}

View File

@ -0,0 +1,106 @@
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package trie
import (
"fmt"
"github.com/ethereum/go-ethereum/common"
)
// Reader wraps the Node and NodeBlob method of a backing trie store.
type Reader interface {
// Node retrieves the trie node with the provided trie identifier, hexary
// node path and the corresponding node hash.
// No error will be returned if the node is not found.
Node(owner common.Hash, path []byte, hash common.Hash) (node, error)
// NodeBlob retrieves the RLP-encoded trie node blob with the provided trie
// identifier, hexary node path and the corresponding node hash.
// No error will be returned if the node is not found.
NodeBlob(owner common.Hash, path []byte, hash common.Hash) ([]byte, error)
}
// NodeReader wraps all the necessary functions for accessing trie node.
type NodeReader interface {
// GetReader returns a reader for accessing all trie nodes with provided
// state root. Nil is returned in case the state is not available.
GetReader(root common.Hash, codec uint64) Reader
}
// trieReader is a wrapper of the underlying node reader. It's not safe
// for concurrent usage.
type trieReader struct {
owner common.Hash
reader Reader
banned map[string]struct{} // Marker to prevent node from being accessed, for tests
}
// newTrieReader initializes the trie reader with the given node reader.
func newTrieReader(stateRoot, owner common.Hash, db NodeReader, codec uint64) (*trieReader, error) {
reader := db.GetReader(stateRoot, codec)
if reader == nil {
return nil, fmt.Errorf("state not found #%x", stateRoot)
}
return &trieReader{owner: owner, reader: reader}, nil
}
// newEmptyReader initializes the pure in-memory reader. All read operations
// should be forbidden and returns the MissingNodeError.
func newEmptyReader() *trieReader {
return &trieReader{}
}
// node retrieves the trie node with the provided trie node information.
// An MissingNodeError will be returned in case the node is not found or
// any error is encountered.
func (r *trieReader) node(path []byte, hash common.Hash) (node, error) {
// Perform the logics in tests for preventing trie node access.
if r.banned != nil {
if _, ok := r.banned[string(path)]; ok {
return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path}
}
}
if r.reader == nil {
return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path}
}
node, err := r.reader.Node(r.owner, path, hash)
if err != nil || node == nil {
return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path, err: err}
}
return node, nil
}
// node retrieves the rlp-encoded trie node with the provided trie node
// information. An MissingNodeError will be returned in case the node is
// not found or any error is encountered.
func (r *trieReader) nodeBlob(path []byte, hash common.Hash) ([]byte, error) {
// Perform the logics in tests for preventing trie node access.
if r.banned != nil {
if _, ok := r.banned[string(path)]; ok {
return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path}
}
}
if r.reader == nil {
return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path}
}
blob, err := r.reader.NodeBlob(r.owner, path, hash)
if err != nil || len(blob) == 0 {
return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path, err: err}
}
return blob, nil
}

View File

@ -14,26 +14,34 @@
// You should have received a copy of the GNU Lesser General Public License // You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package trie_test package trie
import ( import (
"bytes"
"encoding/binary"
"errors"
"fmt" "fmt"
"math/big" "math/big"
"math/rand" "math/rand"
"reflect"
"testing" "testing"
"testing/quick"
"github.com/davecgh/go-spew/spew"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
geth_trie "github.com/ethereum/go-ethereum/trie"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie"
) )
func TestTrieEmpty(t *testing.T) { func init() {
trie := trie.NewEmpty(trie.NewDatabase(rawdb.NewMemoryDatabase())) spew.Config.Indent = " "
spew.Config.DisableMethods = false
}
func TestEmptyTrie(t *testing.T) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
res := trie.Hash() res := trie.Hash()
exp := types.EmptyRootHash exp := types.EmptyRootHash
if res != exp { if res != exp {
@ -41,42 +49,543 @@ func TestTrieEmpty(t *testing.T) {
} }
} }
func TestTrieMissingRoot(t *testing.T) { func TestNull(t *testing.T) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
key := make([]byte, 32)
value := []byte("test")
trie.Update(key, value)
if !bytes.Equal(trie.Get(key), value) {
t.Fatal("wrong value")
}
}
func TestMissingRoot(t *testing.T) {
root := common.HexToHash("0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33") root := common.HexToHash("0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33")
tr, err := newStateTrie(root, trie.NewDatabase(rawdb.NewMemoryDatabase())) trie, err := NewAccountTrie(TrieID(root), NewDatabase(rawdb.NewMemoryDatabase()))
if tr != nil { if trie != nil {
t.Error("New returned non-nil trie for invalid root") t.Error("New returned non-nil trie for invalid root")
} }
if _, ok := err.(*trie.MissingNodeError); !ok { if _, ok := err.(*MissingNodeError); !ok {
t.Errorf("New returned wrong error: %v", err) t.Errorf("New returned wrong error: %v", err)
} }
} }
func TestTrieBasic(t *testing.T) { func TestMissingNodeMemonly(t *testing.T) { testMissingNode(t, true) }
edb := rawdb.NewMemoryDatabase()
db := geth_trie.NewDatabase(edb) func testMissingNode(t *testing.T, memonly bool) {
origtrie := geth_trie.NewEmpty(db) diskdb := rawdb.NewMemoryDatabase()
origtrie.Update([]byte("foo"), packValue(842)) triedb := NewDatabase(diskdb)
expected := commitTrie(t, db, origtrie)
tr := indexTrie(t, edb, expected) trie := NewEmpty(triedb)
got := tr.Hash() updateString(trie, "120000", "qwerqwerqwerqwerqwerqwerqwerqwer")
if expected != got { updateString(trie, "123456", "asdfasdfasdfasdfasdfasdfasdfasdf")
t.Errorf("got %x expected %x", got, expected) root, nodes := trie.Commit(false)
triedb.Update(NewWithNodeSet(nodes))
trie, _ = NewAccountTrie(TrieID(root), triedb)
_, err := trie.TryGet([]byte("120000"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
} }
checkValue(t, tr, []byte("foo")) trie, _ = NewAccountTrie(TrieID(root), triedb)
_, err = trie.TryGet([]byte("120099"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
trie, _ = NewAccountTrie(TrieID(root), triedb)
_, err = trie.TryGet([]byte("123456"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
trie, _ = NewAccountTrie(TrieID(root), triedb)
err = trie.TryUpdate([]byte("120099"), []byte("zxcvzxcvzxcvzxcvzxcvzxcvzxcvzxcv"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
trie, _ = NewAccountTrie(TrieID(root), triedb)
err = trie.TryDelete([]byte("123456"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
} }
func TestTrieTiny(t *testing.T) { hash := common.HexToHash("0xe1d943cc8f061a0c0b98162830b970395ac9315654824bf21b73b891365262f9")
if memonly {
delete(triedb.dirties, hash)
} else {
diskdb.Delete(hash[:])
}
trie, _ = NewAccountTrie(TrieID(root), triedb)
_, err = trie.TryGet([]byte("120000"))
if _, ok := err.(*MissingNodeError); !ok {
t.Errorf("Wrong error: %v", err)
}
trie, _ = NewAccountTrie(TrieID(root), triedb)
_, err = trie.TryGet([]byte("120099"))
if _, ok := err.(*MissingNodeError); !ok {
t.Errorf("Wrong error: %v", err)
}
trie, _ = NewAccountTrie(TrieID(root), triedb)
_, err = trie.TryGet([]byte("123456"))
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
trie, _ = NewAccountTrie(TrieID(root), triedb)
err = trie.TryUpdate([]byte("120099"), []byte("zxcv"))
if _, ok := err.(*MissingNodeError); !ok {
t.Errorf("Wrong error: %v", err)
}
trie, _ = NewAccountTrie(TrieID(root), triedb)
err = trie.TryDelete([]byte("123456"))
if _, ok := err.(*MissingNodeError); !ok {
t.Errorf("Wrong error: %v", err)
}
}
func TestInsert(t *testing.T) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
updateString(trie, "doe", "reindeer")
updateString(trie, "dog", "puppy")
updateString(trie, "dogglesworth", "cat")
exp := common.HexToHash("8aad789dff2f538bca5d8ea56e8abe10f4c7ba3a5dea95fea4cd6e7c3a1168d3")
root := trie.Hash()
if root != exp {
t.Errorf("case 1: exp %x got %x", exp, root)
}
trie = NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
updateString(trie, "A", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
exp = common.HexToHash("d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab")
root, _ = trie.Commit(false)
if root != exp {
t.Errorf("case 2: exp %x got %x", exp, root)
}
}
func TestGet(t *testing.T) {
db := NewDatabase(rawdb.NewMemoryDatabase())
trie := NewEmpty(db)
updateString(trie, "doe", "reindeer")
updateString(trie, "dog", "puppy")
updateString(trie, "dogglesworth", "cat")
for i := 0; i < 2; i++ {
res := getString(trie, "dog")
if !bytes.Equal(res, []byte("puppy")) {
t.Errorf("expected puppy got %x", res)
}
unknown := getString(trie, "unknown")
if unknown != nil {
t.Errorf("expected nil got %x", unknown)
}
if i == 1 {
return
}
root, nodes := trie.Commit(false)
db.Update(NewWithNodeSet(nodes))
trie, _ = NewAccountTrie(TrieID(root), db)
}
}
func TestDelete(t *testing.T) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
vals := []struct{ k, v string }{
{"do", "verb"},
{"ether", "wookiedoo"},
{"horse", "stallion"},
{"shaman", "horse"},
{"doge", "coin"},
{"ether", ""},
{"dog", "puppy"},
{"shaman", ""},
}
for _, val := range vals {
if val.v != "" {
updateString(trie, val.k, val.v)
} else {
deleteString(trie, val.k)
}
}
hash := trie.Hash()
exp := common.HexToHash("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84")
if hash != exp {
t.Errorf("expected %x got %x", exp, hash)
}
}
func TestEmptyValues(t *testing.T) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
vals := []struct{ k, v string }{
{"do", "verb"},
{"ether", "wookiedoo"},
{"horse", "stallion"},
{"shaman", "horse"},
{"doge", "coin"},
{"ether", ""},
{"dog", "puppy"},
{"shaman", ""},
}
for _, val := range vals {
updateString(trie, val.k, val.v)
}
hash := trie.Hash()
exp := common.HexToHash("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84")
if hash != exp {
t.Errorf("expected %x got %x", exp, hash)
}
}
func TestReplication(t *testing.T) {
triedb := NewDatabase(rawdb.NewMemoryDatabase())
trie := NewEmpty(triedb)
vals := []struct{ k, v string }{
{"do", "verb"},
{"ether", "wookiedoo"},
{"horse", "stallion"},
{"shaman", "horse"},
{"doge", "coin"},
{"dog", "puppy"},
{"somethingveryoddindeedthis is", "myothernodedata"},
}
for _, val := range vals {
updateString(trie, val.k, val.v)
}
exp, nodes := trie.Commit(false)
triedb.Update(NewWithNodeSet(nodes))
// create a new trie on top of the database and check that lookups work.
trie2, err := NewAccountTrie(TrieID(exp), triedb)
if err != nil {
t.Fatalf("can't recreate trie at %x: %v", exp, err)
}
for _, kv := range vals {
if string(getString(trie2, kv.k)) != kv.v {
t.Errorf("trie2 doesn't have %q => %q", kv.k, kv.v)
}
}
hash, nodes := trie2.Commit(false)
if hash != exp {
t.Errorf("root failure. expected %x got %x", exp, hash)
}
// recreate the trie after commit
if nodes != nil {
triedb.Update(NewWithNodeSet(nodes))
}
trie2, err = NewAccountTrie(TrieID(hash), triedb)
if err != nil {
t.Fatalf("can't recreate trie at %x: %v", exp, err)
}
// perform some insertions on the new trie.
vals2 := []struct{ k, v string }{
{"do", "verb"},
{"ether", "wookiedoo"},
{"horse", "stallion"},
// {"shaman", "horse"},
// {"doge", "coin"},
// {"ether", ""},
// {"dog", "puppy"},
// {"somethingveryoddindeedthis is", "myothernodedata"},
// {"shaman", ""},
}
for _, val := range vals2 {
updateString(trie2, val.k, val.v)
}
if hash := trie2.Hash(); hash != exp {
t.Errorf("root failure. expected %x got %x", exp, hash)
}
}
func TestLargeValue(t *testing.T) {
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
trie.Update([]byte("key1"), []byte{99, 99, 99, 99})
trie.Update([]byte("key2"), bytes.Repeat([]byte{1}, 32))
trie.Hash()
}
// TestRandomCases tests some cases that were found via random fuzzing
func TestRandomCases(t *testing.T) {
var rt = []randTestStep{
{op: 6, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 0
{op: 6, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 1
{op: 0, key: common.Hex2Bytes("d51b182b95d677e5f1c82508c0228de96b73092d78ce78b2230cd948674f66fd1483bd"), value: common.Hex2Bytes("0000000000000002")}, // step 2
{op: 2, key: common.Hex2Bytes("c2a38512b83107d665c65235b0250002882ac2022eb00711552354832c5f1d030d0e408e"), value: common.Hex2Bytes("")}, // step 3
{op: 3, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 4
{op: 3, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 5
{op: 6, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 6
{op: 3, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 7
{op: 0, key: common.Hex2Bytes("c2a38512b83107d665c65235b0250002882ac2022eb00711552354832c5f1d030d0e408e"), value: common.Hex2Bytes("0000000000000008")}, // step 8
{op: 0, key: common.Hex2Bytes("d51b182b95d677e5f1c82508c0228de96b73092d78ce78b2230cd948674f66fd1483bd"), value: common.Hex2Bytes("0000000000000009")}, // step 9
{op: 2, key: common.Hex2Bytes("fd"), value: common.Hex2Bytes("")}, // step 10
{op: 6, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 11
{op: 6, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 12
{op: 0, key: common.Hex2Bytes("fd"), value: common.Hex2Bytes("000000000000000d")}, // step 13
{op: 6, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 14
{op: 1, key: common.Hex2Bytes("c2a38512b83107d665c65235b0250002882ac2022eb00711552354832c5f1d030d0e408e"), value: common.Hex2Bytes("")}, // step 15
{op: 3, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 16
{op: 0, key: common.Hex2Bytes("c2a38512b83107d665c65235b0250002882ac2022eb00711552354832c5f1d030d0e408e"), value: common.Hex2Bytes("0000000000000011")}, // step 17
{op: 5, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 18
{op: 3, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 19
{op: 0, key: common.Hex2Bytes("d51b182b95d677e5f1c82508c0228de96b73092d78ce78b2230cd948674f66fd1483bd"), value: common.Hex2Bytes("0000000000000014")}, // step 20
{op: 0, key: common.Hex2Bytes("d51b182b95d677e5f1c82508c0228de96b73092d78ce78b2230cd948674f66fd1483bd"), value: common.Hex2Bytes("0000000000000015")}, // step 21
{op: 0, key: common.Hex2Bytes("c2a38512b83107d665c65235b0250002882ac2022eb00711552354832c5f1d030d0e408e"), value: common.Hex2Bytes("0000000000000016")}, // step 22
{op: 5, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 23
{op: 1, key: common.Hex2Bytes("980c393656413a15c8da01978ed9f89feb80b502f58f2d640e3a2f5f7a99a7018f1b573befd92053ac6f78fca4a87268"), value: common.Hex2Bytes("")}, // step 24
{op: 1, key: common.Hex2Bytes("fd"), value: common.Hex2Bytes("")}, // step 25
}
runRandTest(rt)
}
// randTest performs random trie operations.
// Instances of this test are created by Generate.
type randTest []randTestStep
type randTestStep struct {
op int
key []byte // for opUpdate, opDelete, opGet
value []byte // for opUpdate
err error // for debugging
}
const (
opUpdate = iota
opDelete
opGet
opHash
opCommit
opItercheckhash
opNodeDiff
opProve
opMax // boundary value, not an actual op
)
func (randTest) Generate(r *rand.Rand, size int) reflect.Value {
var allKeys [][]byte
genKey := func() []byte {
if len(allKeys) < 2 || r.Intn(100) < 10 {
// new key
key := make([]byte, r.Intn(50))
r.Read(key)
allKeys = append(allKeys, key)
return key
}
// use existing key
return allKeys[r.Intn(len(allKeys))]
}
var steps randTest
for i := 0; i < size; i++ {
step := randTestStep{op: r.Intn(opMax)}
switch step.op {
case opUpdate:
step.key = genKey()
step.value = make([]byte, 8)
binary.BigEndian.PutUint64(step.value, uint64(i))
case opGet, opDelete, opProve:
step.key = genKey()
}
steps = append(steps, step)
}
return reflect.ValueOf(steps)
}
func verifyAccessList(old *Trie, new *Trie, set *NodeSet) error {
deletes, inserts, updates := diffTries(old, new)
// Check insertion set
for path := range inserts {
n, ok := set.nodes[path]
if !ok || n.isDeleted() {
return errors.New("expect new node")
}
_, ok = set.accessList[path]
if ok {
return errors.New("unexpected origin value")
}
}
// Check deletion set
for path, blob := range deletes {
n, ok := set.nodes[path]
if !ok || !n.isDeleted() {
return errors.New("expect deleted node")
}
v, ok := set.accessList[path]
if !ok {
return errors.New("expect origin value")
}
if !bytes.Equal(v, blob) {
return errors.New("invalid origin value")
}
}
// Check update set
for path, blob := range updates {
n, ok := set.nodes[path]
if !ok || n.isDeleted() {
return errors.New("expect updated node")
}
v, ok := set.accessList[path]
if !ok {
return errors.New("expect origin value")
}
if !bytes.Equal(v, blob) {
return errors.New("invalid origin value")
}
}
return nil
}
func runRandTest(rt randTest) bool {
var (
triedb = NewDatabase(rawdb.NewMemoryDatabase())
tr = NewEmpty(triedb)
values = make(map[string]string) // tracks content of the trie
origTrie = NewEmpty(triedb)
)
for i, step := range rt {
// fmt.Printf("{op: %d, key: common.Hex2Bytes(\"%x\"), value: common.Hex2Bytes(\"%x\")}, // step %d\n",
// step.op, step.key, step.value, i)
switch step.op {
case opUpdate:
tr.Update(step.key, step.value)
values[string(step.key)] = string(step.value)
case opDelete:
tr.Delete(step.key)
delete(values, string(step.key))
case opGet:
v := tr.Get(step.key)
want := values[string(step.key)]
if string(v) != want {
rt[i].err = fmt.Errorf("mismatch for key %#x, got %#x want %#x", step.key, v, want)
}
case opProve:
hash := tr.Hash()
if hash == types.EmptyRootHash {
continue
}
proofDb := rawdb.NewMemoryDatabase()
err := tr.Prove(step.key, 0, proofDb)
if err != nil {
rt[i].err = fmt.Errorf("failed for proving key %#x, %v", step.key, err)
}
_, err = VerifyProof(hash, step.key, proofDb)
if err != nil {
rt[i].err = fmt.Errorf("failed for verifying key %#x, %v", step.key, err)
}
case opHash:
tr.Hash()
case opCommit:
root, nodes := tr.Commit(true)
if nodes != nil {
triedb.Update(NewWithNodeSet(nodes))
}
newtr, err := NewAccountTrie(TrieID(root), triedb)
if err != nil {
rt[i].err = err
return false
}
if nodes != nil {
if err := verifyAccessList(origTrie, newtr, nodes); err != nil {
rt[i].err = err
return false
}
}
tr = newtr
origTrie = tr.Copy()
case opItercheckhash:
checktr := NewEmpty(triedb)
it := NewIterator(tr.NodeIterator(nil))
for it.Next() {
checktr.Update(it.Key, it.Value)
}
if tr.Hash() != checktr.Hash() {
rt[i].err = fmt.Errorf("hash mismatch in opItercheckhash")
}
case opNodeDiff:
var (
origIter = origTrie.NodeIterator(nil)
curIter = tr.NodeIterator(nil)
origSeen = make(map[string]struct{})
curSeen = make(map[string]struct{})
)
for origIter.Next(true) {
if origIter.Leaf() {
continue
}
origSeen[string(origIter.Path())] = struct{}{}
}
for curIter.Next(true) {
if curIter.Leaf() {
continue
}
curSeen[string(curIter.Path())] = struct{}{}
}
var (
insertExp = make(map[string]struct{})
deleteExp = make(map[string]struct{})
)
for path := range curSeen {
_, present := origSeen[path]
if !present {
insertExp[path] = struct{}{}
}
}
for path := range origSeen {
_, present := curSeen[path]
if !present {
deleteExp[path] = struct{}{}
}
}
if len(insertExp) != len(tr.tracer.inserts) {
rt[i].err = fmt.Errorf("insert set mismatch")
}
if len(deleteExp) != len(tr.tracer.deletes) {
rt[i].err = fmt.Errorf("delete set mismatch")
}
for insert := range tr.tracer.inserts {
if _, present := insertExp[insert]; !present {
rt[i].err = fmt.Errorf("missing inserted node")
}
}
for del := range tr.tracer.deletes {
if _, present := deleteExp[del]; !present {
rt[i].err = fmt.Errorf("missing deleted node")
}
}
}
// Abort the test on error.
if rt[i].err != nil {
return false
}
}
return true
}
func TestRandom(t *testing.T) {
if err := quick.Check(runRandTest, nil); err != nil {
if cerr, ok := err.(*quick.CheckError); ok {
t.Fatalf("random test iteration %d failed: %s", cerr.Count, spew.Sdump(cerr.In))
}
t.Fatal(err)
}
}
func TestTinyTrie(t *testing.T) {
// Create a realistic account trie to hash // Create a realistic account trie to hash
_, accounts := makeAccounts(5) _, accounts := makeAccounts(5)
edb := rawdb.NewMemoryDatabase() trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
db := geth_trie.NewDatabase(edb)
origtrie := geth_trie.NewEmpty(db)
type testCase struct { type testCase struct {
key, account []byte key, account []byte
root common.Hash root common.Hash
} }
cases := []testCase{ cases := []testCase{
{ {
common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001337"), common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001337"),
@ -92,48 +601,48 @@ func TestTrieTiny(t *testing.T) {
common.HexToHash("0608c1d1dc3905fa22204c7a0e43644831c3b6d3def0f274be623a948197e64a"), common.HexToHash("0608c1d1dc3905fa22204c7a0e43644831c3b6d3def0f274be623a948197e64a"),
}, },
} }
for i, tc := range cases { for i, c := range cases {
t.Run(fmt.Sprintf("case %d", i), func(t *testing.T) { trie.Update(c.key, c.account)
origtrie.Update(tc.key, tc.account) root := trie.Hash()
trie := indexTrie(t, edb, commitTrie(t, db, origtrie)) if root != c.root {
if exp, root := tc.root, trie.Hash(); exp != root { t.Errorf("case %d: got %x, exp %x", i, root, c.root)
t.Errorf("got %x, exp %x", root, exp)
} }
checkValue(t, trie, tc.key) }
}) checktr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
it := NewIterator(trie.NodeIterator(nil))
for it.Next() {
checktr.Update(it.Key, it.Value)
}
if troot, itroot := trie.Hash(), checktr.Hash(); troot != itroot {
t.Fatalf("hash mismatch in opItercheckhash, trie: %x, check: %x", troot, itroot)
} }
} }
func TestTrieMedium(t *testing.T) { func TestCommitAfterHash(t *testing.T) {
// Create a realistic account trie to hash // Create a realistic account trie to hash
addresses, accounts := makeAccounts(1000) addresses, accounts := makeAccounts(1000)
edb := rawdb.NewMemoryDatabase() trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
db := geth_trie.NewDatabase(edb)
origtrie := geth_trie.NewEmpty(db)
var keys [][]byte
for i := 0; i < len(addresses); i++ { for i := 0; i < len(addresses); i++ {
key := crypto.Keccak256(addresses[i][:]) trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i])
if i%50 == 0 {
keys = append(keys, key)
} }
origtrie.Update(key, accounts[i]) // Insert the accounts into the trie and hash it
} trie.Hash()
tr := indexTrie(t, edb, commitTrie(t, db, origtrie)) trie.Commit(false)
root := trie.Hash()
root := tr.Hash()
exp := common.HexToHash("72f9d3f3fe1e1dd7b8936442e7642aef76371472d94319900790053c493f3fe6") exp := common.HexToHash("72f9d3f3fe1e1dd7b8936442e7642aef76371472d94319900790053c493f3fe6")
if exp != root { if exp != root {
t.Errorf("got %x, exp %x", root, exp) t.Errorf("got %x, exp %x", root, exp)
} }
root, _ = trie.Commit(false)
for _, key := range keys { if exp != root {
checkValue(t, tr, key) t.Errorf("got %x, exp %x", root, exp)
} }
} }
// Make deterministically random accounts
func makeAccounts(size int) (addresses [][20]byte, accounts [][]byte) { func makeAccounts(size int) (addresses [][20]byte, accounts [][]byte) {
// Make the random benchmark deterministic
random := rand.New(rand.NewSource(0)) random := rand.New(rand.NewSource(0))
// Create a realistic account trie to hash
addresses = make([][20]byte, size) addresses = make([][20]byte, size)
for i := 0; i < len(addresses); i++ { for i := 0; i < len(addresses); i++ {
data := make([]byte, 20) data := make([]byte, 20)
@ -155,19 +664,34 @@ func makeAccounts(size int) (addresses [][20]byte, accounts [][]byte) {
balanceBytes := make([]byte, numBytes) balanceBytes := make([]byte, numBytes)
random.Read(balanceBytes) random.Read(balanceBytes)
balance := new(big.Int).SetBytes(balanceBytes) balance := new(big.Int).SetBytes(balanceBytes)
acct := &types.StateAccount{Nonce: nonce, Balance: balance, Root: root, CodeHash: code} data, _ := rlp.EncodeToBytes(&types.StateAccount{Nonce: nonce, Balance: balance, Root: root, CodeHash: code})
data, _ := rlp.EncodeToBytes(acct)
accounts[i] = data accounts[i] = data
} }
return addresses, accounts return addresses, accounts
} }
func checkValue(t *testing.T, tr *trie.Trie, key []byte) { func getString(trie *Trie, k string) []byte {
val, err := tr.TryGet(key) return trie.Get([]byte(k))
if err != nil {
t.Fatalf("error getting node: %s", err)
} }
if len(val) == 0 {
t.Errorf("failed to get value for %x", key) func updateString(trie *Trie, k, v string) {
trie.Update([]byte(k), []byte(v))
}
func deleteString(trie *Trie, k string) {
trie.Delete([]byte(k))
}
func TestDecodeNode(t *testing.T) {
t.Parallel()
var (
hash = make([]byte, 20)
elems = make([]byte, 20)
)
for i := 0; i < 5000000; i++ {
prng.Read(hash)
prng.Read(elems)
decodeNode(hash, elems)
} }
} }

View File

@ -1,43 +1,133 @@
package trie_test package trie
import ( import (
"bytes"
"context"
"fmt" "fmt"
"math/big" "math/big"
"math/rand" "math/rand"
"testing" "testing"
"time"
"github.com/jmoiron/sqlx"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/rawdb"
geth_state "github.com/ethereum/go-ethereum/core/state" gethstate "github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
geth_trie "github.com/ethereum/go-ethereum/trie" gethtrie "github.com/ethereum/go-ethereum/trie"
"github.com/jmoiron/sqlx"
pgipfsethdb "github.com/cerc-io/ipfs-ethdb/v5/postgres/v0" pgipfsethdb "github.com/cerc-io/ipfs-ethdb/v5/postgres/v0"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/helper"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/state"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie"
"github.com/ethereum/go-ethereum/statediff/indexer/database/sql/postgres" "github.com/ethereum/go-ethereum/statediff/indexer/database/sql/postgres"
"github.com/ethereum/go-ethereum/statediff/indexer/ipld"
"github.com/ethereum/go-ethereum/statediff/test_helpers" "github.com/ethereum/go-ethereum/statediff/test_helpers"
"github.com/cerc-io/ipld-eth-statedb/internal"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/helper"
) )
type kv struct { var (
dbConfig, _ = postgres.DefaultConfig.WithEnv()
trieConfig = Config{Cache: 256}
)
type kvi struct {
k []byte k []byte
v int64 v int64
} }
type kvMap map[string]*kv type kvMap map[string]*kvi
type kvs struct { type kvsi struct {
k string k string
v int64 v int64
} }
// NewAccountTrie is a shortcut to create a trie using the StateTrieCodec (ie. IPLD MEthStateTrie codec).
func NewAccountTrie(id *ID, db NodeReader) (*Trie, error) {
return New(id, db, StateTrieCodec)
}
// makeTestTrie create a sample test trie to test node-wise reconstruction.
func makeTestTrie(t testing.TB) (*Database, *StateTrie, map[string][]byte) {
// Create an empty trie
triedb := NewDatabase(rawdb.NewMemoryDatabase())
trie, err := NewStateTrie(TrieID(common.Hash{}), triedb, StateTrieCodec)
if err != nil {
t.Fatal(err)
}
// Fill it with some arbitrary data
content := make(map[string][]byte)
for i := byte(0); i < 255; i++ {
// Map the same data under multiple keys
key, val := common.LeftPadBytes([]byte{1, i}, 32), []byte{i}
content[string(key)] = val
trie.Update(key, val)
key, val = common.LeftPadBytes([]byte{2, i}, 32), []byte{i}
content[string(key)] = val
trie.Update(key, val)
// Add some other data to inflate the trie
for j := byte(3); j < 13; j++ {
key, val = common.LeftPadBytes([]byte{j, i}, 32), []byte{j, i}
content[string(key)] = val
trie.Update(key, val)
}
}
root, nodes := trie.Commit(false)
if err := triedb.Update(NewWithNodeSet(nodes)); err != nil {
panic(fmt.Errorf("failed to commit db %v", err))
}
// Re-create the trie based on the new state
trie, err = NewStateTrie(TrieID(root), triedb, StateTrieCodec)
if err != nil {
t.Fatal(err)
}
return triedb, trie, content
}
func forHashedNodes(tr *Trie) map[string][]byte {
var (
it = tr.NodeIterator(nil)
nodes = make(map[string][]byte)
)
for it.Next(true) {
if it.Hash() == (common.Hash{}) {
continue
}
nodes[string(it.Path())] = common.CopyBytes(it.NodeBlob())
}
return nodes
}
func diffTries(trieA, trieB *Trie) (map[string][]byte, map[string][]byte, map[string][]byte) {
var (
nodesA = forHashedNodes(trieA)
nodesB = forHashedNodes(trieB)
inA = make(map[string][]byte) // hashed nodes in trie a but not b
inB = make(map[string][]byte) // hashed nodes in trie b but not a
both = make(map[string][]byte) // hashed nodes in both tries but different value
)
for path, blobA := range nodesA {
if blobB, ok := nodesB[path]; ok {
if bytes.Equal(blobA, blobB) {
continue
}
both[path] = blobA
continue
}
inA[path] = blobA
}
for path, blobB := range nodesB {
if _, ok := nodesA[path]; ok {
continue
}
inB[path] = blobB
}
return inA, inB, both
}
func packValue(val int64) []byte { func packValue(val int64) []byte {
acct := &types.StateAccount{ acct := &types.StateAccount{
Balance: big.NewInt(val), Balance: big.NewInt(val),
@ -51,27 +141,19 @@ func packValue(val int64) []byte {
return acct_rlp return acct_rlp
} }
func unpackValue(val []byte) int64 { func updateTrie(tr *gethtrie.Trie, vals []kvsi) (kvMap, error) {
var acct types.StateAccount
if err := rlp.DecodeBytes(val, &acct); err != nil {
panic(err)
}
return acct.Balance.Int64()
}
func updateTrie(tr *geth_trie.Trie, vals []kvs) (kvMap, error) {
all := kvMap{} all := kvMap{}
for _, val := range vals { for _, val := range vals {
all[string(val.k)] = &kv{[]byte(val.k), val.v} all[string(val.k)] = &kvi{[]byte(val.k), val.v}
tr.Update([]byte(val.k), packValue(val.v)) tr.Update([]byte(val.k), packValue(val.v))
} }
return all, nil return all, nil
} }
func commitTrie(t testing.TB, db *geth_trie.Database, tr *geth_trie.Trie) common.Hash { func commitTrie(t testing.TB, db *gethtrie.Database, tr *gethtrie.Trie) common.Hash {
t.Helper() t.Helper()
root, nodes := tr.Commit(false) root, nodes := tr.Commit(false)
if err := db.Update(geth_trie.NewWithNodeSet(nodes)); err != nil { if err := db.Update(gethtrie.NewWithNodeSet(nodes)); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := db.Commit(root, false); err != nil { if err := db.Commit(root, false); err != nil {
@ -80,16 +162,8 @@ func commitTrie(t testing.TB, db *geth_trie.Database, tr *geth_trie.Trie) common
return root return root
} }
// commit a LevelDB state trie, index to IPLD and return new trie func makePgIpfsEthDB(t testing.TB) ethdb.Database {
func indexTrie(t testing.TB, edb ethdb.Database, root common.Hash) *trie.Trie { pg_db, err := postgres.ConnectSQLX(context.Background(), dbConfig)
t.Helper()
dbConfig.Driver = postgres.PGX
err := helper.IndexChain(dbConfig, geth_state.NewDatabase(edb), common.Hash{}, root)
if err != nil {
t.Fatal(err)
}
pg_db, err := postgres.ConnectSQLX(ctx, dbConfig)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -98,62 +172,48 @@ func indexTrie(t testing.TB, edb ethdb.Database, root common.Hash) *trie.Trie {
t.Fatal(err) t.Fatal(err)
} }
}) })
return pgipfsethdb.NewDatabase(pg_db, internal.MakeCacheConfig(t))
}
ipfs_db := pgipfsethdb.NewDatabase(pg_db, makeCacheConfig(t)) // commit a LevelDB state trie, index to IPLD and return new trie
sdb_db := state.NewDatabase(ipfs_db) func indexTrie(t testing.TB, edb ethdb.Database, root common.Hash) *Trie {
tr, err := newStateTrie(root, sdb_db.TrieDB()) t.Helper()
dbConfig.Driver = postgres.PGX
err := helper.IndexStateDiff(dbConfig, gethstate.NewDatabase(edb), common.Hash{}, root)
if err != nil {
t.Fatal(err)
}
ipfs_db := makePgIpfsEthDB(t)
tr, err := New(TrieID(root), NewDatabase(ipfs_db), StateTrieCodec)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
return tr return tr
} }
func newStateTrie(root common.Hash, db *trie.Database) (*trie.Trie, error) {
tr, err := trie.New(common.Hash{}, root, db, ipld.MEthStateTrie)
if err != nil {
return nil, err
}
return tr, nil
}
// generates a random Geth LevelDB trie of n key-value pairs and corresponding value map // generates a random Geth LevelDB trie of n key-value pairs and corresponding value map
func randomGethTrie(n int, db *geth_trie.Database) (*geth_trie.Trie, kvMap) { func randomGethTrie(n int, db *gethtrie.Database) (*gethtrie.Trie, kvMap) {
trie := geth_trie.NewEmpty(db) trie := gethtrie.NewEmpty(db)
var vals []*kv var vals []*kvi
for i := byte(0); i < 100; i++ { for i := byte(0); i < 100; i++ {
e := &kv{common.LeftPadBytes([]byte{i}, 32), int64(i)} e := &kvi{common.LeftPadBytes([]byte{i}, 32), int64(i)}
e2 := &kv{common.LeftPadBytes([]byte{i + 10}, 32), int64(i)} e2 := &kvi{common.LeftPadBytes([]byte{i + 10}, 32), int64(i)}
vals = append(vals, e, e2) vals = append(vals, e, e2)
} }
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
k := randBytes(32) k := randBytes(32)
v := rand.Int63() v := rand.Int63()
vals = append(vals, &kv{k, v}) vals = append(vals, &kvi{k, v})
} }
all := kvMap{} all := kvMap{}
for _, val := range vals { for _, val := range vals {
all[string(val.k)] = &kv{[]byte(val.k), val.v} all[string(val.k)] = &kvi{[]byte(val.k), val.v}
trie.Update([]byte(val.k), packValue(val.v)) trie.Update([]byte(val.k), packValue(val.v))
} }
return trie, all return trie, all
} }
// generates a random IPLD-indexed trie
func randomTrie(t testing.TB, n int) (*trie.Trie, kvMap) {
edb := rawdb.NewMemoryDatabase()
db := geth_trie.NewDatabase(edb)
orig, vals := randomGethTrie(n, db)
root := commitTrie(t, db, orig)
trie := indexTrie(t, edb, root)
return trie, vals
}
func randBytes(n int) []byte {
r := make([]byte, n)
rand.Read(r)
return r
}
// TearDownDB is used to tear down the watcher dbs after tests // TearDownDB is used to tear down the watcher dbs after tests
func TearDownDB(db *sqlx.DB) error { func TearDownDB(db *sqlx.DB) error {
tx, err := db.Beginx() tx, err := db.Beginx()
@ -179,12 +239,3 @@ func TearDownDB(db *sqlx.DB) error {
} }
return tx.Commit() return tx.Commit()
} }
// returns a cache config with unique name (groupcache names are global)
func makeCacheConfig(t testing.TB) pgipfsethdb.CacheConfig {
return pgipfsethdb.CacheConfig{
Name: t.Name(),
Size: 3000000, // 3MB
ExpiryDuration: time.Hour,
}
}