diff --git a/go.mod b/go.mod
index 6c1f715..67d2676 100644
--- a/go.mod
+++ b/go.mod
@@ -5,6 +5,7 @@ go 1.18
require (
github.com/VictoriaMetrics/fastcache v1.6.0
github.com/cerc-io/ipfs-ethdb/v5 v5.0.0-alpha
+ github.com/davecgh/go-spew v1.1.1
github.com/ethereum/go-ethereum v1.11.5
github.com/hashicorp/golang-lru v0.5.5-0.20210104140557-80c98217689d
github.com/ipfs/go-cid v0.2.0
@@ -21,13 +22,13 @@ require (
github.com/DataDog/zstd v1.5.2 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/btcsuite/btcd/btcec/v2 v2.2.0 // indirect
+ github.com/btcsuite/btcd/chaincfg/chainhash v1.0.2 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/cockroachdb/errors v1.9.1 // indirect
github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b // indirect
github.com/cockroachdb/pebble v0.0.0-20230209160836-829675f94811 // indirect
github.com/cockroachdb/redact v1.1.3 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect
- github.com/davecgh/go-spew v1.1.1 // indirect
github.com/deckarep/golang-set/v2 v2.1.0 // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 // indirect
github.com/edsrzf/mmap-go v1.0.0 // indirect
@@ -112,4 +113,4 @@ require (
lukechampine.com/blake3 v1.1.6 // indirect
)
-replace github.com/ethereum/go-ethereum v1.11.5 => github.com/cerc-io/go-ethereum v1.11.5-statediff-5.0.1-alpha
+replace github.com/ethereum/go-ethereum v1.11.5 => github.com/cerc-io/go-ethereum v1.11.5-statediff-5.0.4
diff --git a/go.sum b/go.sum
index fef91fa..e7e8faa 100644
--- a/go.sum
+++ b/go.sum
@@ -17,12 +17,13 @@ github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5
github.com/aymerick/raymond v2.0.3-0.20180322193309-b565731e1464+incompatible/go.mod h1:osfaiScAUVup+UC9Nfq76eWqDhXlp+4UYaA8uhTBO6g=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
-github.com/btcsuite/btcd v0.22.0-beta h1:LTDpDKUM5EeOFBPM8IXpinEcmZ6FWfNZbE3lfrfdnWo=
github.com/btcsuite/btcd/btcec/v2 v2.2.0 h1:fzn1qaOt32TuLjFlkzYSsBC35Q3KUjT1SwPxiMSCF5k=
github.com/btcsuite/btcd/btcec/v2 v2.2.0/go.mod h1:U7MHm051Al6XmscBQ0BoNydpOTsFAn707034b5nY8zU=
+github.com/btcsuite/btcd/chaincfg/chainhash v1.0.2 h1:KdUfX2zKommPRa+PD0sWZUyXe9w277ABlgELO7H04IM=
+github.com/btcsuite/btcd/chaincfg/chainhash v1.0.2/go.mod h1:7SFka0XMvUgj3hfZtydOrQY2mwhPclbT2snogU7SQQc=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
-github.com/cerc-io/go-ethereum v1.11.5-statediff-5.0.1-alpha h1:x9muuG0Z2W/UAkwAq+0SXhYG9MCP6SJlGEfFNQa6JPI=
-github.com/cerc-io/go-ethereum v1.11.5-statediff-5.0.1-alpha/go.mod h1:DIk2wFexjyzvyjuzSOtBEIAPRNZTnLXNbIHEyq1Igek=
+github.com/cerc-io/go-ethereum v1.11.5-statediff-5.0.3-alpha h1:rhRmK/NeWMnQ07E4DuLb7WSh9FMotlXMPPaOrf8GJwM=
+github.com/cerc-io/go-ethereum v1.11.5-statediff-5.0.3-alpha/go.mod h1:DIk2wFexjyzvyjuzSOtBEIAPRNZTnLXNbIHEyq1Igek=
github.com/cerc-io/ipfs-ethdb/v5 v5.0.0-alpha h1:I1iXTaIjbTH8ehzNXmT2waXcYBifi1yjK6FK3W3a0Pg=
github.com/cerc-io/ipfs-ethdb/v5 v5.0.0-alpha/go.mod h1:EGAdV/YewEADFDDVF1k9GNwy8vNWR29Xb87sRHgMIng=
github.com/cespare/cp v0.1.0 h1:SE+dxFebS7Iik5LK0tsi1k9ZCxEaFX4AjQmoyA+1dJk=
diff --git a/internal/util.go b/internal/util.go
index a19df0a..634c18a 100644
--- a/internal/util.go
+++ b/internal/util.go
@@ -1,6 +1,10 @@
package internal
import (
+ "testing"
+ "time"
+
+ pgipfsethdb "github.com/cerc-io/ipfs-ethdb/v5/postgres/v0"
"github.com/ipfs/go-cid"
"github.com/multiformats/go-multihash"
)
@@ -12,3 +16,12 @@ func Keccak256ToCid(codec uint64, h []byte) (cid.Cid, error) {
}
return cid.NewCidV1(codec, buf), nil
}
+
+// MakeCacheConfig returns a cache config with a unique name (groupcache names are global).
+func MakeCacheConfig(t testing.TB) pgipfsethdb.CacheConfig {
+ return pgipfsethdb.CacheConfig{
+ Name: t.Name(),
+ Size: 3000000, // 3MB
+ ExpiryDuration: time.Hour,
+ }
+}
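
For orientation, a minimal sketch (not part of the diff) of how these helpers fit together: the cache name must be unique per test because groupcache registers names globally, and DB values are keyed by a CID derived from a keccak-256 digest. The codec constant `0x55` (the raw-binary multicodec, matching `ipld.RawBinary`) and the digest input are illustrative.

```go
package internal

import (
	"testing"

	"github.com/ethereum/go-ethereum/crypto"
)

// sketchCodeKey shows how a keccak-256 code hash becomes a blockstore
// key: Keccak256ToCid wraps the digest in a CIDv1 under the given codec
// (0x55 is the raw-binary multicodec, matching ipld.RawBinary).
func sketchCodeKey(t *testing.T) []byte {
	cfg := MakeCacheConfig(t) // unique groupcache name per test
	_ = cfg                   // would be passed to pgipfsethdb.NewDatabase

	digest := crypto.Keccak256([]byte("contract code"))
	c, err := Keccak256ToCid(0x55, digest)
	if err != nil {
		t.Fatal(err)
	}
	return c.Bytes() // the DB key used in place of the bare hash
}
```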
diff --git a/trie_by_cid/doc.go b/trie_by_cid/doc.go
new file mode 100644
index 0000000..4b9b440
--- /dev/null
+++ b/trie_by_cid/doc.go
@@ -0,0 +1,3 @@
+// Package trie_by_cid is a near-complete copy of go-ethereum/trie and go-ethereum/core/state, modified
+// to use a v0 IPFS blockstore as the backing DB, i.e. DB values are indexed by CID rather than by hash.
+package trie_by_cid
diff --git a/trie_by_cid/helper/statediff_helper.go b/trie_by_cid/helper/statediff_helper.go
index b03abf6..030c2b3 100644
--- a/trie_by_cid/helper/statediff_helper.go
+++ b/trie_by_cid/helper/statediff_helper.go
@@ -20,7 +20,10 @@ var (
mockTD = big.NewInt(1)
)
-func IndexChain(dbConfig postgres.Config, stateCache state.Database, rootA, rootB common.Hash) error {
+// IndexStateDiff indexes a single statediff.
+// - uses TestChainConfig
+// - block hash/number are left as zero
+func IndexStateDiff(dbConfig postgres.Config, stateCache state.Database, rootA, rootB common.Hash) error {
_, indexer, err := indexer.NewStateDiffIndexer(
context.Background(), ChainConfig, node.Info{}, dbConfig)
if err != nil {
@@ -28,11 +31,10 @@ func IndexChain(dbConfig postgres.Config, stateCache state.Database, rootA, root
}
defer indexer.Close() // fixme: hangs when using PGX driver
- // generating statediff payload for block, and transform the data into Postgres
builder := statediff.NewBuilder(stateCache)
block := types.NewBlock(&types.Header{Root: rootB}, nil, nil, nil, NewHasher())
- // todo: use dummy block hashes to just produce trie structure for testing
+	// uses zero block hash/number, since we only need the trie structure here
args := statediff.Args{
OldStateRoot: rootA,
NewStateRoot: rootB,
@@ -45,12 +47,7 @@ func IndexChain(dbConfig postgres.Config, stateCache state.Database, rootA, root
if err != nil {
return err
}
- // for _, node := range diff.Nodes {
- // err := indexer.PushStateNode(tx, node, block.Hash().String())
- // if err != nil {
- // return err
- // }
- // }
+ // we don't need to index diff.Nodes since we are just interested in the trie
for _, ipld := range diff.IPLDs {
if err := indexer.PushIPLD(tx, ipld); err != nil {
return err
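
A sketch of how a test might drive this helper. The fixture names are illustrative, and it assumes `stateCache` is a go-ethereum `core/state` `Database`, as the statediff builder expects.

```go
package helper_test

import (
	"testing"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/core/state"
	"github.com/ethereum/go-ethereum/statediff/indexer/database/sql/postgres"

	"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/helper"
)

// indexFixture writes the IPLDs for the diff between two state roots
// into Postgres, so the trie_by_cid packages can read them back by CID.
func indexFixture(t *testing.T, cfg postgres.Config, cache state.Database, rootA, rootB common.Hash) {
	if err := helper.IndexStateDiff(cfg, cache, rootA, rootB); err != nil {
		t.Fatal(err)
	}
}
```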
diff --git a/trie_by_cid/state/access_list.go b/trie_by_cid/state/access_list.go
new file mode 100644
index 0000000..4194691
--- /dev/null
+++ b/trie_by_cid/state/access_list.go
@@ -0,0 +1,136 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package state
+
+import (
+ "github.com/ethereum/go-ethereum/common"
+)
+
+type accessList struct {
+ addresses map[common.Address]int
+ slots []map[common.Hash]struct{}
+}
+
+// ContainsAddress returns true if the address is in the access list.
+func (al *accessList) ContainsAddress(address common.Address) bool {
+ _, ok := al.addresses[address]
+ return ok
+}
+
+// Contains checks if a slot within an account is present in the access list, returning
+// separate flags for the presence of the account and the slot respectively.
+func (al *accessList) Contains(address common.Address, slot common.Hash) (addressPresent bool, slotPresent bool) {
+ idx, ok := al.addresses[address]
+ if !ok {
+ // no such address (and hence zero slots)
+ return false, false
+ }
+ if idx == -1 {
+ // address yes, but no slots
+ return true, false
+ }
+ _, slotPresent = al.slots[idx][slot]
+ return true, slotPresent
+}
+
+// newAccessList creates a new accessList.
+func newAccessList() *accessList {
+ return &accessList{
+ addresses: make(map[common.Address]int),
+ }
+}
+
+// Copy creates an independent copy of an accessList.
+func (a *accessList) Copy() *accessList {
+ cp := newAccessList()
+ for k, v := range a.addresses {
+ cp.addresses[k] = v
+ }
+ cp.slots = make([]map[common.Hash]struct{}, len(a.slots))
+ for i, slotMap := range a.slots {
+ newSlotmap := make(map[common.Hash]struct{}, len(slotMap))
+ for k := range slotMap {
+ newSlotmap[k] = struct{}{}
+ }
+ cp.slots[i] = newSlotmap
+ }
+ return cp
+}
+
+// AddAddress adds an address to the access list, and returns 'true' if the operation
+// caused a change (addr was not previously in the list).
+func (al *accessList) AddAddress(address common.Address) bool {
+ if _, present := al.addresses[address]; present {
+ return false
+ }
+ al.addresses[address] = -1
+ return true
+}
+
+// AddSlot adds the specified (addr, slot) combo to the access list.
+// Return values are:
+// - address added
+// - slot added
+// For any 'true' value returned, a corresponding journal entry must be made.
+func (al *accessList) AddSlot(address common.Address, slot common.Hash) (addrChange bool, slotChange bool) {
+ idx, addrPresent := al.addresses[address]
+ if !addrPresent || idx == -1 {
+ // Address not present, or addr present but no slots there
+ al.addresses[address] = len(al.slots)
+ slotmap := map[common.Hash]struct{}{slot: {}}
+ al.slots = append(al.slots, slotmap)
+ return !addrPresent, true
+ }
+ // There is already an (address,slot) mapping
+ slotmap := al.slots[idx]
+ if _, ok := slotmap[slot]; !ok {
+ slotmap[slot] = struct{}{}
+ // Journal add slot change
+ return false, true
+ }
+ // No changes required
+ return false, false
+}
+
+// DeleteSlot removes an (address, slot)-tuple from the access list.
+// This operation needs to be performed in the same order as the addition happened.
+// This method is meant to be used by the journal, which maintains ordering of
+// operations.
+func (al *accessList) DeleteSlot(address common.Address, slot common.Hash) {
+ idx, addrOk := al.addresses[address]
+ // There are two ways this can fail
+ if !addrOk {
+ panic("reverting slot change, address not present in list")
+ }
+ slotmap := al.slots[idx]
+ delete(slotmap, slot)
+ // If that was the last (first) slot, remove it
+ // Since additions and rollbacks are always performed in order,
+ // we can delete the item without worrying about screwing up later indices
+ if len(slotmap) == 0 {
+ al.slots = al.slots[:idx]
+ al.addresses[address] = -1
+ }
+}
+
+// DeleteAddress removes an address from the access list. This operation
+// needs to be performed in the same order as the addition happened.
+// This method is meant to be used by the journal, which maintains ordering of
+// operations.
+func (al *accessList) DeleteAddress(address common.Address) {
+ delete(al.addresses, address)
+}
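
The return flags of AddSlot encode the invariant that the journal relies on (see DeleteSlot/DeleteAddress above). A sketch, not part of the diff, of the caller's obligation:

```go
package state

import "github.com/ethereum/go-ethereum/common"

// sketchJournalObligation: when AddSlot reports the address as newly
// added, the caller must record *two* journal entries, so that a later
// revert can blindly DeleteAddress after DeleteSlot has run.
func sketchJournalObligation(al *accessList, addr common.Address, slot common.Hash) int {
	entries := 0
	addrAdded, slotAdded := al.AddSlot(addr, slot)
	if addrAdded {
		entries++ // accessListAddAccountChange{address: &addr}
	}
	if slotAdded {
		entries++ // accessListAddSlotChange{address: &addr, slot: &slot}
	}
	return entries
}
```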
diff --git a/trie_by_cid/state/database.go b/trie_by_cid/state/database.go
index 30adbb7..e6920da 100644
--- a/trie_by_cid/state/database.go
+++ b/trie_by_cid/state/database.go
@@ -1,14 +1,30 @@
+// Copyright 2017 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
package state
import (
"errors"
+ "fmt"
- "github.com/VictoriaMetrics/fastcache"
"github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/common/lru"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/statediff/indexer/ipld"
- lru "github.com/hashicorp/golang-lru"
"github.com/cerc-io/ipld-eth-statedb/internal"
"github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie"
@@ -28,7 +44,10 @@ type Database interface {
OpenTrie(root common.Hash) (Trie, error)
// OpenStorageTrie opens the storage trie of an account.
- OpenStorageTrie(addrHash, root common.Hash) (Trie, error)
+ OpenStorageTrie(stateRoot common.Hash, addrHash, root common.Hash) (Trie, error)
+
+ // CopyTrie returns an independent copy of the given trie.
+ CopyTrie(Trie) Trie
// ContractCode retrieves a particular contract's code.
ContractCode(codeHash common.Hash) ([]byte, error)
@@ -36,17 +55,79 @@ type Database interface {
// ContractCodeSize retrieves a particular contracts code's size.
ContractCodeSize(codeHash common.Hash) (int, error)
+ // DiskDB returns the underlying key-value disk database.
+ DiskDB() ethdb.KeyValueStore
+
// TrieDB retrieves the low level trie database used for data storage.
TrieDB() *trie.Database
}
// Trie is a Ethereum Merkle Patricia trie.
type Trie interface {
+ // GetKey returns the sha3 preimage of a hashed key that was previously used
+ // to store a value.
+ //
+ // TODO(fjl): remove this when StateTrie is removed
+ GetKey([]byte) []byte
+
+ // TryGet returns the value for key stored in the trie. The value bytes must
+ // not be modified by the caller. If a node was not found in the database, a
+ // trie.MissingNodeError is returned.
TryGet(key []byte) ([]byte, error)
+
+ // TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not
+ // possible to use keybyte-encoding as the path might contain odd nibbles.
TryGetNode(path []byte) ([]byte, int, error)
- TryGetAccount(key []byte) (*types.StateAccount, error)
+
+ // TryGetAccount abstracts an account read from the trie. It retrieves the
+ // account blob from the trie with provided account address and decodes it
+ // with associated decoding algorithm. If the specified account is not in
+	// the trie, nil will be returned. If the trie is corrupted (e.g. some nodes
+ // are missing or the account blob is incorrect for decoding), an error will
+ // be returned.
+ TryGetAccount(address common.Address) (*types.StateAccount, error)
+
+ // TryUpdate associates key with value in the trie. If value has length zero, any
+ // existing value is deleted from the trie. The value bytes must not be modified
+ // by the caller while they are stored in the trie. If a node was not found in the
+ // database, a trie.MissingNodeError is returned.
+ TryUpdate(key, value []byte) error
+
+ // TryUpdateAccount abstracts an account write to the trie. It encodes the
+ // provided account object with associated algorithm and then updates it
+ // in the trie with provided address.
+ TryUpdateAccount(address common.Address, account *types.StateAccount) error
+
+ // TryDelete removes any existing value for key from the trie. If a node was not
+ // found in the database, a trie.MissingNodeError is returned.
+ TryDelete(key []byte) error
+
+ // TryDeleteAccount abstracts an account deletion from the trie.
+ TryDeleteAccount(address common.Address) error
+
+ // Hash returns the root hash of the trie. It does not write to the database and
+ // can be used even if the trie doesn't have one.
Hash() common.Hash
+
+	// Commit collects all dirty nodes in the trie and replaces them with the
+	// corresponding node hash. All collected nodes (including dirty leaves if
+	// collectLeaf is true) will be encapsulated into a nodeset for return.
+	// The returned nodeset can be nil if the trie is clean (nothing to commit).
+	// Once the trie is committed, it's not usable anymore. A new trie must
+	// be created with the new root and updated trie database for subsequent use.
+ Commit(collectLeaf bool) (common.Hash, *trie.NodeSet)
+
+ // NodeIterator returns an iterator that returns nodes of the trie. Iteration
+ // starts at the key after the given start key.
NodeIterator(startKey []byte) trie.NodeIterator
+
+ // Prove constructs a Merkle proof for key. The result contains all encoded nodes
+ // on the path to the value at key. The value itself is also included in the last
+ // node and can be retrieved by verifying the proof.
+ //
+ // If the trie does not contain a value for key, the returned proof contains all
+ // nodes of the longest existing prefix of the key (at least the root), ending
+ // with the node that proves the absence of the key.
Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) error
}
@@ -61,23 +142,34 @@ func NewDatabase(db ethdb.Database) Database {
// is safe for concurrent use and retains a lot of collapsed RLP trie nodes in a
// large memory cache.
func NewDatabaseWithConfig(db ethdb.Database, config *trie.Config) Database {
- csc, _ := lru.New(codeSizeCacheSize)
return &cachingDB{
- db: trie.NewDatabaseWithConfig(db, config),
- codeSizeCache: csc,
- codeCache: fastcache.New(codeCacheSize),
+ disk: db,
+ codeSizeCache: lru.NewCache[common.Hash, int](codeSizeCacheSize),
+ codeCache: lru.NewSizeConstrainedCache[common.Hash, []byte](codeCacheSize),
+ triedb: trie.NewDatabaseWithConfig(db, config),
+ }
+}
+
+// NewDatabaseWithNodeDB creates a state database with an already initialized node database.
+func NewDatabaseWithNodeDB(db ethdb.Database, triedb *trie.Database) Database {
+ return &cachingDB{
+ disk: db,
+ codeSizeCache: lru.NewCache[common.Hash, int](codeSizeCacheSize),
+ codeCache: lru.NewSizeConstrainedCache[common.Hash, []byte](codeCacheSize),
+ triedb: triedb,
}
}
type cachingDB struct {
- db *trie.Database
- codeSizeCache *lru.Cache
- codeCache *fastcache.Cache
+ disk ethdb.KeyValueStore
+ codeSizeCache *lru.Cache[common.Hash, int]
+ codeCache *lru.SizeConstrainedCache[common.Hash, []byte]
+ triedb *trie.Database
}
// OpenTrie opens the main account trie at a specific root hash.
func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) {
- tr, err := trie.NewStateTrie(common.Hash{}, root, db.db)
+ tr, err := trie.NewStateTrie(trie.StateTrieID(root), db.triedb, trie.StateTrieCodec)
if err != nil {
return nil, err
}
@@ -85,29 +177,40 @@ func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) {
}
// OpenStorageTrie opens the storage trie of an account.
-func (db *cachingDB) OpenStorageTrie(addrHash, root common.Hash) (Trie, error) {
- tr, err := trie.NewStorageTrie(addrHash, root, db.db)
+func (db *cachingDB) OpenStorageTrie(stateRoot common.Hash, addrHash, root common.Hash) (Trie, error) {
+ tr, err := trie.NewStateTrie(trie.StorageTrieID(stateRoot, addrHash, root), db.triedb, trie.StorageTrieCodec)
if err != nil {
return nil, err
}
return tr, nil
}
+// CopyTrie returns an independent copy of the given trie.
+func (db *cachingDB) CopyTrie(t Trie) Trie {
+ switch t := t.(type) {
+ case *trie.StateTrie:
+ return t.Copy()
+ default:
+ panic(fmt.Errorf("unknown trie type %T", t))
+ }
+}
+
// ContractCode retrieves a particular contract's code.
func (db *cachingDB) ContractCode(codeHash common.Hash) ([]byte, error) {
- if code := db.codeCache.Get(nil, codeHash.Bytes()); len(code) > 0 {
+ code, _ := db.codeCache.Get(codeHash)
+ if len(code) > 0 {
return code, nil
}
- codeCID, err := internal.Keccak256ToCid(ipld.RawBinary, codeHash.Bytes())
+ cid, err := internal.Keccak256ToCid(ipld.RawBinary, codeHash.Bytes())
if err != nil {
return nil, err
}
- code, err := db.db.DiskDB().Get(codeCID.Bytes())
+ code, err = db.disk.Get(cid.Bytes())
if err != nil {
return nil, err
}
if len(code) > 0 {
- db.codeCache.Set(codeHash.Bytes(), code)
+ db.codeCache.Add(codeHash, code)
db.codeSizeCache.Add(codeHash, len(code))
return code, nil
}
@@ -117,13 +220,18 @@ func (db *cachingDB) ContractCode(codeHash common.Hash) ([]byte, error) {
// ContractCodeSize retrieves a particular contracts code's size.
func (db *cachingDB) ContractCodeSize(codeHash common.Hash) (int, error) {
if cached, ok := db.codeSizeCache.Get(codeHash); ok {
- return cached.(int), nil
+ return cached, nil
}
code, err := db.ContractCode(codeHash)
return len(code), err
}
+// DiskDB returns the underlying key-value disk database.
+func (db *cachingDB) DiskDB() ethdb.KeyValueStore {
+ return db.disk
+}
+
// TrieDB retrieves any intermediate trie-node caching layer.
func (db *cachingDB) TrieDB() *trie.Database {
- return db.db
+ return db.triedb
}
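
The cache rework above also changes the miss path in ContractCode. A condensed sketch of that path (names match the diff): after both LRU caches miss, the keccak code hash is translated to a CID and the CID bytes are used as the key into the backing store.

```go
package state

import (
	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/ethdb"
	"github.com/ethereum/go-ethereum/statediff/indexer/ipld"

	"github.com/cerc-io/ipld-eth-statedb/internal"
)

// codeByCID is the uncached read path of ContractCode: the keccak code
// hash is converted to a raw-binary CID, and the CID bytes (not the
// hash) are the key into the backing key-value store.
func codeByCID(disk ethdb.KeyValueStore, codeHash common.Hash) ([]byte, error) {
	c, err := internal.Keccak256ToCid(ipld.RawBinary, codeHash.Bytes())
	if err != nil {
		return nil, err
	}
	return disk.Get(c.Bytes())
}
```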
diff --git a/trie_by_cid/state/journal.go b/trie_by_cid/state/journal.go
new file mode 100644
index 0000000..1722fb4
--- /dev/null
+++ b/trie_by_cid/state/journal.go
@@ -0,0 +1,282 @@
+// Copyright 2016 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package state
+
+import (
+ "math/big"
+
+ "github.com/ethereum/go-ethereum/common"
+)
+
+// journalEntry is a modification entry in the state change journal that can be
+// reverted on demand.
+type journalEntry interface {
+ // revert undoes the changes introduced by this journal entry.
+ revert(*StateDB)
+
+ // dirtied returns the Ethereum address modified by this journal entry.
+ dirtied() *common.Address
+}
+
+// journal contains the list of state modifications applied since the last state
+// commit. These are tracked to be able to be reverted in the case of an execution
+// exception or request for reversal.
+type journal struct {
+ entries []journalEntry // Current changes tracked by the journal
+ dirties map[common.Address]int // Dirty accounts and the number of changes
+}
+
+// newJournal creates a new initialized journal.
+func newJournal() *journal {
+ return &journal{
+ dirties: make(map[common.Address]int),
+ }
+}
+
+// append inserts a new modification entry to the end of the change journal.
+func (j *journal) append(entry journalEntry) {
+ j.entries = append(j.entries, entry)
+ if addr := entry.dirtied(); addr != nil {
+ j.dirties[*addr]++
+ }
+}
+
+// revert undoes a batch of journalled modifications along with any reverted
+// dirty handling too.
+func (j *journal) revert(statedb *StateDB, snapshot int) {
+ for i := len(j.entries) - 1; i >= snapshot; i-- {
+ // Undo the changes made by the operation
+ j.entries[i].revert(statedb)
+
+ // Drop any dirty tracking induced by the change
+ if addr := j.entries[i].dirtied(); addr != nil {
+ if j.dirties[*addr]--; j.dirties[*addr] == 0 {
+ delete(j.dirties, *addr)
+ }
+ }
+ }
+ j.entries = j.entries[:snapshot]
+}
+
+// dirty explicitly sets an address to dirty, even if the change entries would
+// otherwise suggest it as clean. This method is an ugly hack to handle the RIPEMD
+// precompile consensus exception.
+func (j *journal) dirty(addr common.Address) {
+ j.dirties[addr]++
+}
+
+// length returns the current number of entries in the journal.
+func (j *journal) length() int {
+ return len(j.entries)
+}
+
+type (
+ // Changes to the account trie.
+ createObjectChange struct {
+ account *common.Address
+ }
+ resetObjectChange struct {
+ prev *stateObject
+ prevdestruct bool
+ }
+ suicideChange struct {
+ account *common.Address
+ prev bool // whether account had already suicided
+ prevbalance *big.Int
+ }
+
+ // Changes to individual accounts.
+ balanceChange struct {
+ account *common.Address
+ prev *big.Int
+ }
+ nonceChange struct {
+ account *common.Address
+ prev uint64
+ }
+ storageChange struct {
+ account *common.Address
+ key, prevalue common.Hash
+ }
+ codeChange struct {
+ account *common.Address
+ prevcode, prevhash []byte
+ }
+
+ // Changes to other state values.
+ refundChange struct {
+ prev uint64
+ }
+ addLogChange struct {
+ txhash common.Hash
+ }
+ addPreimageChange struct {
+ hash common.Hash
+ }
+ touchChange struct {
+ account *common.Address
+ }
+ // Changes to the access list
+ accessListAddAccountChange struct {
+ address *common.Address
+ }
+ accessListAddSlotChange struct {
+ address *common.Address
+ slot *common.Hash
+ }
+
+ transientStorageChange struct {
+ account *common.Address
+ key, prevalue common.Hash
+ }
+)
+
+func (ch createObjectChange) revert(s *StateDB) {
+ delete(s.stateObjects, *ch.account)
+ delete(s.stateObjectsDirty, *ch.account)
+}
+
+func (ch createObjectChange) dirtied() *common.Address {
+ return ch.account
+}
+
+func (ch resetObjectChange) revert(s *StateDB) {
+ s.setStateObject(ch.prev)
+ if !ch.prevdestruct {
+ delete(s.stateObjectsDestruct, ch.prev.address)
+ }
+}
+
+func (ch resetObjectChange) dirtied() *common.Address {
+ return nil
+}
+
+func (ch suicideChange) revert(s *StateDB) {
+ obj := s.getStateObject(*ch.account)
+ if obj != nil {
+ obj.suicided = ch.prev
+ obj.setBalance(ch.prevbalance)
+ }
+}
+
+func (ch suicideChange) dirtied() *common.Address {
+ return ch.account
+}
+
+var ripemd = common.HexToAddress("0000000000000000000000000000000000000003")
+
+func (ch touchChange) revert(s *StateDB) {
+}
+
+func (ch touchChange) dirtied() *common.Address {
+ return ch.account
+}
+
+func (ch balanceChange) revert(s *StateDB) {
+ s.getStateObject(*ch.account).setBalance(ch.prev)
+}
+
+func (ch balanceChange) dirtied() *common.Address {
+ return ch.account
+}
+
+func (ch nonceChange) revert(s *StateDB) {
+ s.getStateObject(*ch.account).setNonce(ch.prev)
+}
+
+func (ch nonceChange) dirtied() *common.Address {
+ return ch.account
+}
+
+func (ch codeChange) revert(s *StateDB) {
+ s.getStateObject(*ch.account).setCode(common.BytesToHash(ch.prevhash), ch.prevcode)
+}
+
+func (ch codeChange) dirtied() *common.Address {
+ return ch.account
+}
+
+func (ch storageChange) revert(s *StateDB) {
+ s.getStateObject(*ch.account).setState(ch.key, ch.prevalue)
+}
+
+func (ch storageChange) dirtied() *common.Address {
+ return ch.account
+}
+
+func (ch transientStorageChange) revert(s *StateDB) {
+ s.setTransientState(*ch.account, ch.key, ch.prevalue)
+}
+
+func (ch transientStorageChange) dirtied() *common.Address {
+ return nil
+}
+
+func (ch refundChange) revert(s *StateDB) {
+ s.refund = ch.prev
+}
+
+func (ch refundChange) dirtied() *common.Address {
+ return nil
+}
+
+func (ch addLogChange) revert(s *StateDB) {
+ logs := s.logs[ch.txhash]
+ if len(logs) == 1 {
+ delete(s.logs, ch.txhash)
+ } else {
+ s.logs[ch.txhash] = logs[:len(logs)-1]
+ }
+ s.logSize--
+}
+
+func (ch addLogChange) dirtied() *common.Address {
+ return nil
+}
+
+func (ch addPreimageChange) revert(s *StateDB) {
+ delete(s.preimages, ch.hash)
+}
+
+func (ch addPreimageChange) dirtied() *common.Address {
+ return nil
+}
+
+func (ch accessListAddAccountChange) revert(s *StateDB) {
+ /*
+		One important invariant here is that whenever a (addr, slot) is added, if the
+		addr is not already present, the add causes two journal entries:
+		- one for the address,
+		- one for the (address,slot)
+		Therefore, when unrolling the change, we can always blindly delete the
+		(addr) at this point, since no storage adds can remain when we come upon
+		a single (addr) change.
+ */
+ s.accessList.DeleteAddress(*ch.address)
+}
+
+func (ch accessListAddAccountChange) dirtied() *common.Address {
+ return nil
+}
+
+func (ch accessListAddSlotChange) revert(s *StateDB) {
+ s.accessList.DeleteSlot(*ch.address, *ch.slot)
+}
+
+func (ch accessListAddSlotChange) dirtied() *common.Address {
+ return nil
+}
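
The journal is the backbone of Snapshot/RevertToSnapshot in statedb.go below. A sketch of that interaction, using the revision fields introduced in this diff:

```go
package state

// sketchSnapshotRevert shows how StateDB revisions map onto the journal:
// Snapshot records the current journal length under a fresh id, and
// RevertToSnapshot undoes entries in reverse order back down to that
// length (see journal.revert above).
func sketchSnapshotRevert(s *StateDB) {
	// Snapshot: remember the current journal index under a fresh id.
	id := s.nextRevisionId
	s.nextRevisionId++
	s.validRevisions = append(s.validRevisions, revision{id, s.journal.length()})

	// ... state mutations append journal entries here ...

	// Revert: unwind everything appended since the snapshot was taken.
	rev := s.validRevisions[len(s.validRevisions)-1]
	s.journal.revert(s, rev.journalIndex)
	s.validRevisions = s.validRevisions[:len(s.validRevisions)-1]
}
```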
diff --git a/trie_by_cid/state/metrics.go b/trie_by_cid/state/metrics.go
new file mode 100644
index 0000000..e702ef3
--- /dev/null
+++ b/trie_by_cid/state/metrics.go
@@ -0,0 +1,30 @@
+// Copyright 2021 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package state
+
+import "github.com/ethereum/go-ethereum/metrics"
+
+var (
+ accountUpdatedMeter = metrics.NewRegisteredMeter("state/update/account", nil)
+ storageUpdatedMeter = metrics.NewRegisteredMeter("state/update/storage", nil)
+ accountDeletedMeter = metrics.NewRegisteredMeter("state/delete/account", nil)
+ storageDeletedMeter = metrics.NewRegisteredMeter("state/delete/storage", nil)
+ accountTrieUpdatedMeter = metrics.NewRegisteredMeter("state/update/accountnodes", nil)
+ storageTriesUpdatedMeter = metrics.NewRegisteredMeter("state/update/storagenodes", nil)
+ accountTrieDeletedMeter = metrics.NewRegisteredMeter("state/delete/accountnodes", nil)
+ storageTriesDeletedMeter = metrics.NewRegisteredMeter("state/delete/storagenodes", nil)
+)
diff --git a/trie_by_cid/state/state_object.go b/trie_by_cid/state/state_object.go
index 6b53c67..e1afcb3 100644
--- a/trie_by_cid/state/state_object.go
+++ b/trie_by_cid/state/state_object.go
@@ -1,8 +1,25 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
package state
import (
"bytes"
"fmt"
+ "io"
"math/big"
"time"
@@ -11,15 +28,7 @@ import (
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/metrics"
"github.com/ethereum/go-ethereum/rlp"
-)
-
-var (
- // emptyRoot is the known root hash of an empty trie.
- // this is calculated as: emptyRoot = crypto.Keccak256(rlp.Encode([][]byte{}))
- // that is, the keccak356 hash of the rlp encoding of an empty trie node (empty byte slice array)
- emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421")
- // emptyCodeHash is the CodeHash for an EOA, for an account without contract code deployed
- emptyCodeHash = crypto.Keccak256(nil)
+ // "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie"
)
type Code []byte
@@ -34,7 +43,6 @@ func (s Storage) String() (str string) {
for key, value := range s {
str += fmt.Sprintf("%X : %X\n", key, value)
}
-
return
}
@@ -43,32 +51,39 @@ func (s Storage) Copy() Storage {
for key, value := range s {
cpy[key] = value
}
-
return cpy
}
-// stateObject represents an Ethereum account which is being accessed.
+// stateObject represents an Ethereum account which is being modified.
//
// The usage pattern is as follows:
// First you need to obtain a state object.
-// Account values can be accessed through the object.
+// Account values can be accessed and modified through the object.
type stateObject struct {
address common.Address
addrHash common.Hash // hash of ethereum address of the account
data types.StateAccount
db *StateDB
- // Caches.
+ // Write caches.
trie Trie // storage trie, which becomes non-nil on first access
code Code // contract bytecode, which gets set when code is loaded
- originStorage Storage // Storage cache of original entries to dedup rewrites, reset for every transaction
- fakeStorage Storage // Fake storage which constructed by caller for debugging purpose.
+ originStorage Storage // Storage cache of original entries to dedup rewrites, reset for every transaction
+ pendingStorage Storage // Storage entries that need to be flushed to disk, at the end of an entire block
+ dirtyStorage Storage // Storage entries that have been modified in the current transaction execution
+
+ // Cache flags.
+ // When an object is marked suicided it will be deleted from the trie
+ // during the "update" phase of the state transition.
+ dirtyCode bool // true if the code was updated
+ suicided bool
+ deleted bool
}
// empty returns whether the account is considered empty.
func (s *stateObject) empty() bool {
- return s.data.Nonce == 0 && s.data.Balance.Sign() == 0 && bytes.Equal(s.data.CodeHash, emptyCodeHash)
+ return s.data.Nonce == 0 && s.data.Balance.Sign() == 0 && bytes.Equal(s.data.CodeHash, types.EmptyCodeHash.Bytes())
}
// newObject creates a state object.
@@ -77,71 +92,128 @@ func newObject(db *StateDB, address common.Address, data types.StateAccount) *st
data.Balance = new(big.Int)
}
if data.CodeHash == nil {
- data.CodeHash = emptyCodeHash
+ data.CodeHash = types.EmptyCodeHash.Bytes()
}
if data.Root == (common.Hash{}) {
- data.Root = emptyRoot
+ data.Root = types.EmptyRootHash
}
return &stateObject{
- db: db,
- address: address,
- addrHash: crypto.Keccak256Hash(address[:]),
- data: data,
- originStorage: make(Storage),
+ db: db,
+ address: address,
+ addrHash: crypto.Keccak256Hash(address[:]),
+ data: data,
+ originStorage: make(Storage),
+ pendingStorage: make(Storage),
+ dirtyStorage: make(Storage),
}
}
-// setError remembers the first non-nil error it is called with.
-func (s *stateObject) setError(err error) {
- s.db.setError(err)
+// EncodeRLP implements rlp.Encoder.
+func (s *stateObject) EncodeRLP(w io.Writer) error {
+ return rlp.Encode(w, &s.data)
}
-func (s *stateObject) getTrie(db Database) Trie {
+func (s *stateObject) markSuicided() {
+ s.suicided = true
+}
+
+func (s *stateObject) touch() {
+ s.db.journal.append(touchChange{
+ account: &s.address,
+ })
+ if s.address == ripemd {
+ // Explicitly put it in the dirty-cache, which is otherwise generated from
+ // flattened journals.
+ s.db.journal.dirty(s.address)
+ }
+}
+
+// getTrie returns the associated storage trie. The trie will be opened
+// if it's not loaded previously. An error will be returned if trie can't
+// be loaded.
+func (s *stateObject) getTrie(db Database) (Trie, error) {
if s.trie == nil {
- // // Try fetching from prefetcher first
- // // We don't prefetch empty tries
- // if s.data.Root != emptyRoot && s.db.prefetcher != nil {
- // // When the miner is creating the pending state, there is no
- // // prefetcher
- // s.trie = s.db.prefetcher.trie(s.addrHash, s.data.Root)
- // }
+ // Try fetching from prefetcher first
+ // We don't prefetch empty tries
+ if s.data.Root != types.EmptyRootHash && s.db.prefetcher != nil {
+ // When the miner is creating the pending state, there is no
+ // prefetcher
+ s.trie = s.db.prefetcher.trie(s.addrHash, s.data.Root)
+ }
if s.trie == nil {
- var err error
- s.trie, err = db.OpenStorageTrie(s.addrHash, s.data.Root)
+ tr, err := db.OpenStorageTrie(s.db.originalRoot, s.addrHash, s.data.Root)
if err != nil {
- s.trie, _ = db.OpenStorageTrie(s.addrHash, common.Hash{})
- s.setError(fmt.Errorf("can't create storage trie: %w", err))
+ return nil, err
}
+ s.trie = tr
}
}
- return s.trie
+ return s.trie, nil
+}
+
+// GetState retrieves a value from the account storage trie.
+func (s *stateObject) GetState(db Database, key common.Hash) common.Hash {
+ // If we have a dirty value for this state entry, return it
+ value, dirty := s.dirtyStorage[key]
+ if dirty {
+ return value
+ }
+ // Otherwise return the entry's original value
+ return s.GetCommittedState(db, key)
}
// GetCommittedState retrieves a value from the committed account storage trie.
-func (s *stateObject) GetState(db Database, key common.Hash) common.Hash {
- // If the fake storage is set, only lookup the state here(in the debugging mode)
- if s.fakeStorage != nil {
- return s.fakeStorage[key]
+func (s *stateObject) GetCommittedState(db Database, key common.Hash) common.Hash {
+ // If we have a pending write or clean cached, return that
+ if value, pending := s.pendingStorage[key]; pending {
+ return value
}
- // If we have a cached value, return that
if value, cached := s.originStorage[key]; cached {
return value
}
- // If no live objects are available, load from the database.
- start := time.Now()
- enc, err := s.getTrie(db).TryGet(key.Bytes())
- if metrics.EnabledExpensive {
- s.db.StorageReads += time.Since(start)
- }
- if err != nil {
- s.setError(err)
+ // If the object was destructed in *this* block (and potentially resurrected),
+ // the storage has been cleared out, and we should *not* consult the previous
+ // database about any storage values. The only possible alternatives are:
+ // 1) resurrect happened, and new slot values were set -- those should
+	//    have been handled via pendingStorage above.
+	// 2) we don't have new values, and can deliver an empty response back
+ if _, destructed := s.db.stateObjectsDestruct[s.address]; destructed {
return common.Hash{}
}
+ // If no live objects are available, attempt to use snapshots
+ var (
+ enc []byte
+ err error
+ )
+ if s.db.snap != nil {
+ start := time.Now()
+ enc, err = s.db.snap.Storage(s.addrHash, crypto.Keccak256Hash(key.Bytes()))
+ if metrics.EnabledExpensive {
+ s.db.SnapshotStorageReads += time.Since(start)
+ }
+ }
+ // If the snapshot is unavailable or reading from it fails, load from the database.
+ if s.db.snap == nil || err != nil {
+ start := time.Now()
+ tr, err := s.getTrie(db)
+ if err != nil {
+ s.db.setError(err)
+ return common.Hash{}
+ }
+ enc, err = tr.TryGet(key.Bytes())
+ if metrics.EnabledExpensive {
+ s.db.StorageReads += time.Since(start)
+ }
+ if err != nil {
+ s.db.setError(err)
+ return common.Hash{}
+ }
+ }
var value common.Hash
if len(enc) > 0 {
_, content, _, err := rlp.Split(enc)
if err != nil {
- s.setError(err)
+ s.db.setError(err)
}
value.SetBytes(content)
}
@@ -149,6 +221,182 @@ func (s *stateObject) GetState(db Database, key common.Hash) common.Hash {
return value
}
+// SetState updates a value in account storage.
+func (s *stateObject) SetState(db Database, key, value common.Hash) {
+ // If the new value is the same as old, don't set
+ prev := s.GetState(db, key)
+ if prev == value {
+ return
+ }
+ // New value is different, update and journal the change
+ s.db.journal.append(storageChange{
+ account: &s.address,
+ key: key,
+ prevalue: prev,
+ })
+ s.setState(key, value)
+}
+
+func (s *stateObject) setState(key, value common.Hash) {
+ s.dirtyStorage[key] = value
+}
+
+// finalise moves all dirty storage slots into the pending area to be hashed or
+// committed later. It is invoked at the end of every transaction.
+func (s *stateObject) finalise(prefetch bool) {
+ slotsToPrefetch := make([][]byte, 0, len(s.dirtyStorage))
+ for key, value := range s.dirtyStorage {
+ s.pendingStorage[key] = value
+ if value != s.originStorage[key] {
+ slotsToPrefetch = append(slotsToPrefetch, common.CopyBytes(key[:])) // Copy needed for closure
+ }
+ }
+ if s.db.prefetcher != nil && prefetch && len(slotsToPrefetch) > 0 && s.data.Root != types.EmptyRootHash {
+ s.db.prefetcher.prefetch(s.addrHash, s.data.Root, slotsToPrefetch)
+ }
+ if len(s.dirtyStorage) > 0 {
+ s.dirtyStorage = make(Storage)
+ }
+}
+
+// updateTrie writes cached storage modifications into the object's storage trie.
+// It will return nil if the trie has not been loaded and no changes have been
+// made. An error will be returned if the trie can't be loaded/updated correctly.
+func (s *stateObject) updateTrie(db Database) (Trie, error) {
+ // Make sure all dirty slots are finalized into the pending storage area
+ s.finalise(false) // Don't prefetch anymore, pull directly if need be
+ if len(s.pendingStorage) == 0 {
+ return s.trie, nil
+ }
+ // Track the amount of time wasted on updating the storage trie
+ if metrics.EnabledExpensive {
+ defer func(start time.Time) { s.db.StorageUpdates += time.Since(start) }(time.Now())
+ }
+ // The snapshot storage map for the object
+ var (
+ storage map[common.Hash][]byte
+ hasher = s.db.hasher
+ )
+ tr, err := s.getTrie(db)
+ if err != nil {
+ s.db.setError(err)
+ return nil, err
+ }
+ // Insert all the pending updates into the trie
+ usedStorage := make([][]byte, 0, len(s.pendingStorage))
+ for key, value := range s.pendingStorage {
+ // Skip noop changes, persist actual changes
+ if value == s.originStorage[key] {
+ continue
+ }
+ s.originStorage[key] = value
+
+ var v []byte
+ if (value == common.Hash{}) {
+ if err := tr.TryDelete(key[:]); err != nil {
+ s.db.setError(err)
+ return nil, err
+ }
+ s.db.StorageDeleted += 1
+ } else {
+ // Encoding []byte cannot fail, ok to ignore the error.
+ v, _ = rlp.EncodeToBytes(common.TrimLeftZeroes(value[:]))
+ if err := tr.TryUpdate(key[:], v); err != nil {
+ s.db.setError(err)
+ return nil, err
+ }
+ s.db.StorageUpdated += 1
+ }
+ // If state snapshotting is active, cache the data til commit
+ if s.db.snap != nil {
+ if storage == nil {
+ // Retrieve the old storage map, if available, create a new one otherwise
+ if storage = s.db.snapStorage[s.addrHash]; storage == nil {
+ storage = make(map[common.Hash][]byte)
+ s.db.snapStorage[s.addrHash] = storage
+ }
+ }
+ storage[crypto.HashData(hasher, key[:])] = v // v will be nil if it's deleted
+ }
+ usedStorage = append(usedStorage, common.CopyBytes(key[:])) // Copy needed for closure
+ }
+ if s.db.prefetcher != nil {
+ s.db.prefetcher.used(s.addrHash, s.data.Root, usedStorage)
+ }
+ if len(s.pendingStorage) > 0 {
+ s.pendingStorage = make(Storage)
+ }
+ return tr, nil
+}
+
+// updateRoot sets the account's storage root to the current root hash of
+// the storage trie. It is a no-op if the trie fails to update or nothing changed.
+func (s *stateObject) updateRoot(db Database) {
+ tr, err := s.updateTrie(db)
+ if err != nil {
+ return
+ }
+ // If nothing changed, don't bother with hashing anything
+ if tr == nil {
+ return
+ }
+ // Track the amount of time wasted on hashing the storage trie
+ if metrics.EnabledExpensive {
+ defer func(start time.Time) { s.db.StorageHashes += time.Since(start) }(time.Now())
+ }
+ s.data.Root = tr.Hash()
+}
+
+// AddBalance adds amount to s's balance.
+// It is used to add funds to the destination account of a transfer.
+func (s *stateObject) AddBalance(amount *big.Int) {
+ // EIP161: We must check emptiness for the objects such that the account
+ // clearing (0,0,0 objects) can take effect.
+ if amount.Sign() == 0 {
+ if s.empty() {
+ s.touch()
+ }
+ return
+ }
+ s.SetBalance(new(big.Int).Add(s.Balance(), amount))
+}
+
+// SubBalance removes amount from s's balance.
+// It is used to remove funds from the origin account of a transfer.
+func (s *stateObject) SubBalance(amount *big.Int) {
+ if amount.Sign() == 0 {
+ return
+ }
+ s.SetBalance(new(big.Int).Sub(s.Balance(), amount))
+}
+
+func (s *stateObject) SetBalance(amount *big.Int) {
+ s.db.journal.append(balanceChange{
+ account: &s.address,
+ prev: new(big.Int).Set(s.data.Balance),
+ })
+ s.setBalance(amount)
+}
+
+func (s *stateObject) setBalance(amount *big.Int) {
+ s.data.Balance = amount
+}
+
+func (s *stateObject) deepCopy(db *StateDB) *stateObject {
+ stateObject := newObject(db, s.address, s.data)
+ if s.trie != nil {
+ stateObject.trie = db.db.CopyTrie(s.trie)
+ }
+ stateObject.code = s.code
+ stateObject.dirtyStorage = s.dirtyStorage.Copy()
+ stateObject.originStorage = s.originStorage.Copy()
+ stateObject.pendingStorage = s.pendingStorage.Copy()
+ stateObject.suicided = s.suicided
+ stateObject.dirtyCode = s.dirtyCode
+ stateObject.deleted = s.deleted
+ return stateObject
+}
+
//
// Attribute accessors
//
@@ -163,12 +411,12 @@ func (s *stateObject) Code(db Database) []byte {
if s.code != nil {
return s.code
}
- if bytes.Equal(s.CodeHash(), emptyCodeHash) {
+ if bytes.Equal(s.CodeHash(), types.EmptyCodeHash.Bytes()) {
return nil
}
code, err := db.ContractCode(common.BytesToHash(s.CodeHash()))
if err != nil {
- s.setError(fmt.Errorf("can't load code hash %x: %v", s.CodeHash(), err))
+ s.db.setError(fmt.Errorf("can't load code hash %x: %v", s.CodeHash(), err))
}
s.code = code
return code
@@ -181,16 +429,44 @@ func (s *stateObject) CodeSize(db Database) int {
if s.code != nil {
return len(s.code)
}
- if bytes.Equal(s.CodeHash(), emptyCodeHash) {
+ if bytes.Equal(s.CodeHash(), types.EmptyCodeHash.Bytes()) {
return 0
}
size, err := db.ContractCodeSize(common.BytesToHash(s.CodeHash()))
if err != nil {
- s.setError(fmt.Errorf("can't load code size %x: %v", s.CodeHash(), err))
+ s.db.setError(fmt.Errorf("can't load code size %x: %v", s.CodeHash(), err))
}
return size
}
+func (s *stateObject) SetCode(codeHash common.Hash, code []byte) {
+ prevcode := s.Code(s.db.db)
+ s.db.journal.append(codeChange{
+ account: &s.address,
+ prevhash: s.CodeHash(),
+ prevcode: prevcode,
+ })
+ s.setCode(codeHash, code)
+}
+
+func (s *stateObject) setCode(codeHash common.Hash, code []byte) {
+ s.code = code
+ s.data.CodeHash = codeHash[:]
+ s.dirtyCode = true
+}
+
+func (s *stateObject) SetNonce(nonce uint64) {
+ s.db.journal.append(nonceChange{
+ account: &s.address,
+ prev: s.data.Nonce,
+ })
+ s.setNonce(nonce)
+}
+
+func (s *stateObject) setNonce(nonce uint64) {
+ s.data.Nonce = nonce
+}
+
func (s *stateObject) CodeHash() []byte {
return s.data.CodeHash
}
@@ -202,10 +478,3 @@ func (s *stateObject) Balance() *big.Int {
func (s *stateObject) Nonce() uint64 {
return s.data.Nonce
}
-
-// Never called, but must be present to allow stateObject to be used
-// as a vm.Account interface that also satisfies the vm.ContractRef
-// interface. Interfaces are awesome.
-func (s *stateObject) Value() *big.Int {
- panic("Value on stateObject should never be called")
-}
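
Storage slot values round-trip through a trimmed RLP encoding, per updateTrie and GetCommittedState above. A sketch of the symmetry, not part of the diff:

```go
package state

import (
	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/rlp"
)

// encodeSlot mirrors updateTrie: leading zero bytes are trimmed and the
// remainder RLP-encoded, so slot value 0x...01 is stored as one byte.
func encodeSlot(value common.Hash) []byte {
	v, _ := rlp.EncodeToBytes(common.TrimLeftZeroes(value[:]))
	return v
}

// decodeSlot mirrors GetCommittedState: split the RLP string and
// right-align the content back into a 32-byte hash.
func decodeSlot(enc []byte) (value common.Hash) {
	if len(enc) > 0 {
		_, content, _, _ := rlp.Split(enc)
		value.SetBytes(content)
	}
	return value
}
```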
diff --git a/trie_by_cid/state/state_object_test.go b/trie_by_cid/state/state_object_test.go
new file mode 100644
index 0000000..42fd778
--- /dev/null
+++ b/trie_by_cid/state/state_object_test.go
@@ -0,0 +1,46 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package state
+
+import (
+ "bytes"
+ "testing"
+
+ "github.com/ethereum/go-ethereum/common"
+)
+
+func BenchmarkCutOriginal(b *testing.B) {
+ value := common.HexToHash("0x01")
+ for i := 0; i < b.N; i++ {
+ bytes.TrimLeft(value[:], "\x00")
+ }
+}
+
+func BenchmarkCutsetterFn(b *testing.B) {
+ value := common.HexToHash("0x01")
+ cutSetFn := func(r rune) bool { return r == 0 }
+ for i := 0; i < b.N; i++ {
+ bytes.TrimLeftFunc(value[:], cutSetFn)
+ }
+}
+
+func BenchmarkCutCustomTrim(b *testing.B) {
+ value := common.HexToHash("0x01")
+ for i := 0; i < b.N; i++ {
+ common.TrimLeftZeroes(value[:])
+ }
+}
diff --git a/trie_by_cid/state/state_test.go b/trie_by_cid/state/state_test.go
new file mode 100644
index 0000000..6b3fa85
--- /dev/null
+++ b/trie_by_cid/state/state_test.go
@@ -0,0 +1,219 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package state
+
+import (
+ "bytes"
+ "context"
+ "math/big"
+ "testing"
+
+ pgipfsethdb "github.com/cerc-io/ipfs-ethdb/v5/postgres/v0"
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/ethdb"
+ "github.com/ethereum/go-ethereum/statediff/indexer/database/sql/postgres"
+
+ "github.com/cerc-io/ipld-eth-statedb/internal"
+)
+
+var (
+ testCtx = context.Background()
+ testConfig, _ = postgres.DefaultConfig.WithEnv()
+ teardownStatements = []string{`TRUNCATE ipld.blocks`}
+)
+
+type stateTest struct {
+ db ethdb.Database
+ state *StateDB
+}
+
+func newStateTest(t *testing.T) *stateTest {
+ pool, err := postgres.ConnectSQLX(testCtx, testConfig)
+ if err != nil {
+ t.Fatal(err)
+ }
+ db := pgipfsethdb.NewDatabase(pool, internal.MakeCacheConfig(t))
+ sdb, err := New(common.Hash{}, NewDatabase(db), nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return &stateTest{db: db, state: sdb}
+}
+
+func TestNull(t *testing.T) {
+ s := newStateTest(t)
+ address := common.HexToAddress("0x823140710bf13990e4500136726d8b55")
+ s.state.CreateAccount(address)
+ //value := common.FromHex("0x823140710bf13990e4500136726d8b55")
+ var value common.Hash
+
+ s.state.SetState(address, common.Hash{}, value)
+ // s.state.Commit(false)
+
+ if value := s.state.GetState(address, common.Hash{}); value != (common.Hash{}) {
+ t.Errorf("expected empty current value, got %x", value)
+ }
+ if value := s.state.GetCommittedState(address, common.Hash{}); value != (common.Hash{}) {
+ t.Errorf("expected empty committed value, got %x", value)
+ }
+}
+
+func TestSnapshot(t *testing.T) {
+ stateobjaddr := common.BytesToAddress([]byte("aa"))
+ var storageaddr common.Hash
+ data1 := common.BytesToHash([]byte{42})
+ data2 := common.BytesToHash([]byte{43})
+ s := newStateTest(t)
+
+ // snapshot the genesis state
+ genesis := s.state.Snapshot()
+
+ // set initial state object value
+ s.state.SetState(stateobjaddr, storageaddr, data1)
+ snapshot := s.state.Snapshot()
+
+ // set a new state object value, revert it and ensure correct content
+ s.state.SetState(stateobjaddr, storageaddr, data2)
+ s.state.RevertToSnapshot(snapshot)
+
+ if v := s.state.GetState(stateobjaddr, storageaddr); v != data1 {
+ t.Errorf("wrong storage value %v, want %v", v, data1)
+ }
+ if v := s.state.GetCommittedState(stateobjaddr, storageaddr); v != (common.Hash{}) {
+ t.Errorf("wrong committed storage value %v, want %v", v, common.Hash{})
+ }
+
+ // revert up to the genesis state and ensure correct content
+ s.state.RevertToSnapshot(genesis)
+ if v := s.state.GetState(stateobjaddr, storageaddr); v != (common.Hash{}) {
+ t.Errorf("wrong storage value %v, want %v", v, common.Hash{})
+ }
+ if v := s.state.GetCommittedState(stateobjaddr, storageaddr); v != (common.Hash{}) {
+ t.Errorf("wrong committed storage value %v, want %v", v, common.Hash{})
+ }
+}
+
+func TestSnapshotEmpty(t *testing.T) {
+ s := newStateTest(t)
+ s.state.RevertToSnapshot(s.state.Snapshot())
+}
+
+func TestSnapshot2(t *testing.T) {
+ state, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil)
+
+ stateobjaddr0 := common.BytesToAddress([]byte("so0"))
+ stateobjaddr1 := common.BytesToAddress([]byte("so1"))
+ var storageaddr common.Hash
+
+ data0 := common.BytesToHash([]byte{17})
+ data1 := common.BytesToHash([]byte{18})
+
+ state.SetState(stateobjaddr0, storageaddr, data0)
+ state.SetState(stateobjaddr1, storageaddr, data1)
+
+ // db, trie are already non-empty values
+ so0 := state.getStateObject(stateobjaddr0)
+ so0.SetBalance(big.NewInt(42))
+ so0.SetNonce(43)
+ so0.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e'}), []byte{'c', 'a', 'f', 'e'})
+ so0.suicided = false
+ so0.deleted = false
+ state.setStateObject(so0)
+
+ // root, _ := state.Commit(false)
+ // state, _ = New(root, state.db, state.snaps)
+
+ // and one with deleted == true
+ so1 := state.getStateObject(stateobjaddr1)
+ so1.SetBalance(big.NewInt(52))
+ so1.SetNonce(53)
+ so1.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e', '2'}), []byte{'c', 'a', 'f', 'e', '2'})
+ so1.suicided = true
+ so1.deleted = true
+ state.setStateObject(so1)
+
+ so1 = state.getStateObject(stateobjaddr1)
+ if so1 != nil {
+ t.Fatalf("deleted object not nil when getting")
+ }
+
+ snapshot := state.Snapshot()
+ state.RevertToSnapshot(snapshot)
+
+ so0Restored := state.getStateObject(stateobjaddr0)
+ // Update lazily-loaded values before comparing.
+ so0Restored.GetState(state.db, storageaddr)
+ so0Restored.Code(state.db)
+ // non-deleted is equal (restored)
+ compareStateObjects(so0Restored, so0, t)
+
+ // deleted should be nil, both before and after restore of state copy
+ so1Restored := state.getStateObject(stateobjaddr1)
+ if so1Restored != nil {
+ t.Fatalf("deleted object not nil after restoring snapshot: %+v", so1Restored)
+ }
+}
+
+func compareStateObjects(so0, so1 *stateObject, t *testing.T) {
+ if so0.Address() != so1.Address() {
+ t.Fatalf("Address mismatch: have %v, want %v", so0.address, so1.address)
+ }
+ if so0.Balance().Cmp(so1.Balance()) != 0 {
+ t.Fatalf("Balance mismatch: have %v, want %v", so0.Balance(), so1.Balance())
+ }
+ if so0.Nonce() != so1.Nonce() {
+ t.Fatalf("Nonce mismatch: have %v, want %v", so0.Nonce(), so1.Nonce())
+ }
+ if so0.data.Root != so1.data.Root {
+ t.Errorf("Root mismatch: have %x, want %x", so0.data.Root[:], so1.data.Root[:])
+ }
+ if !bytes.Equal(so0.CodeHash(), so1.CodeHash()) {
+ t.Fatalf("CodeHash mismatch: have %v, want %v", so0.CodeHash(), so1.CodeHash())
+ }
+ if !bytes.Equal(so0.code, so1.code) {
+ t.Fatalf("Code mismatch: have %v, want %v", so0.code, so1.code)
+ }
+
+ if len(so1.dirtyStorage) != len(so0.dirtyStorage) {
+ t.Errorf("Dirty storage size mismatch: have %d, want %d", len(so1.dirtyStorage), len(so0.dirtyStorage))
+ }
+ for k, v := range so1.dirtyStorage {
+ if so0.dirtyStorage[k] != v {
+ t.Errorf("Dirty storage key %x mismatch: have %v, want %v", k, so0.dirtyStorage[k], v)
+ }
+ }
+ for k, v := range so0.dirtyStorage {
+ if so1.dirtyStorage[k] != v {
+ t.Errorf("Dirty storage key %x mismatch: have %v, want none.", k, v)
+ }
+ }
+ if len(so1.originStorage) != len(so0.originStorage) {
+ t.Errorf("Origin storage size mismatch: have %d, want %d", len(so1.originStorage), len(so0.originStorage))
+ }
+ for k, v := range so1.originStorage {
+ if so0.originStorage[k] != v {
+ t.Errorf("Origin storage key %x mismatch: have %v, want %v", k, so0.originStorage[k], v)
+ }
+ }
+ for k, v := range so0.originStorage {
+ if so1.originStorage[k] != v {
+ t.Errorf("Origin storage key %x mismatch: have %v, want none.", k, v)
+ }
+ }
+}
diff --git a/trie_by_cid/state/statedb.go b/trie_by_cid/state/statedb.go
index fc0c0b8..b4c4369 100644
--- a/trie_by_cid/state/statedb.go
+++ b/trie_by_cid/state/statedb.go
@@ -1,17 +1,45 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+// Package state provides a caching layer atop the Ethereum state trie.
package state
import (
"errors"
"fmt"
"math/big"
+ "sort"
"time"
"github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/state/snapshot"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/metrics"
+ "github.com/ethereum/go-ethereum/params"
+ "github.com/ethereum/go-ethereum/rlp"
+
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie"
)
+type revision struct {
+ id int
+ journalIndex int
+}
+
type proofList [][]byte
func (n *proofList) Put(key []byte, value []byte) error {
@@ -28,46 +56,127 @@ func (n *proofList) Delete(key []byte) error {
// nested states. It's the general query interface to retrieve:
// * Contracts
// * Accounts
-//
-// This implementation is read-only and performs no journaling, prefetching, or metrics tracking.
type StateDB struct {
- db Database
- trie Trie
- hasher crypto.KeccakState
+ db Database
+ prefetcher *triePrefetcher
+ trie Trie
+ hasher crypto.KeccakState
+
+ // originalRoot is the pre-state root, before any changes were made.
+ // It will be updated when the Commit is called.
+ originalRoot common.Hash
+
+ snaps *snapshot.Tree
+ snap snapshot.Snapshot
+ snapAccounts map[common.Hash][]byte
+ snapStorage map[common.Hash]map[common.Hash][]byte
// This map holds 'live' objects, which will get modified while processing a state transition.
- stateObjects map[common.Address]*stateObject
+ stateObjects map[common.Address]*stateObject
+ stateObjectsPending map[common.Address]struct{} // State objects finalized but not yet written to the trie
+ stateObjectsDirty map[common.Address]struct{} // State objects modified in the current execution
+ stateObjectsDestruct map[common.Address]struct{} // State objects destructed in the block
// DB error.
// State objects are used by the consensus core and VM which are
// unable to deal with database-level errors. Any error that occurs
- // during a database read is memoized here and will eventually be returned
- // by StateDB.Commit.
+ // during a database read is memoized here and will eventually be
+ // returned by StateDB.Commit. Notably, this error is also shared
+ // by all cached state objects in case the database failure occurs
+ // when accessing state of accounts.
dbErr error
+ // The refund counter, also used by state transitioning.
+ refund uint64
+
+ thash common.Hash
+ txIndex int
+ logs map[common.Hash][]*types.Log
+ logSize uint
+
preimages map[common.Hash][]byte
+ // Per-transaction access list
+ accessList *accessList
+
+ // Transient storage
+ transientStorage transientStorage
+
+ // Journal of state modifications. This is the backbone of
+ // Snapshot and RevertToSnapshot.
+ journal *journal
+ validRevisions []revision
+ nextRevisionId int
+
// Measurements gathered during execution for debugging purposes
- AccountReads time.Duration
- StorageReads time.Duration
+ AccountReads time.Duration
+ AccountHashes time.Duration
+ AccountUpdates time.Duration
+ StorageReads time.Duration
+ StorageHashes time.Duration
+ StorageUpdates time.Duration
+ SnapshotAccountReads time.Duration
+ SnapshotStorageReads time.Duration
+
+ AccountUpdated int
+ StorageUpdated int
+ AccountDeleted int
+ StorageDeleted int
}
// New creates a new state from a given trie.
-func New(root common.Hash, db Database) (*StateDB, error) {
+func New(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, error) {
tr, err := db.OpenTrie(root)
if err != nil {
return nil, err
}
sdb := &StateDB{
- db: db,
- trie: tr,
- stateObjects: make(map[common.Address]*stateObject),
- preimages: make(map[common.Hash][]byte),
- hasher: crypto.NewKeccakState(),
+ db: db,
+ trie: tr,
+ originalRoot: root,
+ snaps: snaps,
+ stateObjects: make(map[common.Address]*stateObject),
+ stateObjectsPending: make(map[common.Address]struct{}),
+ stateObjectsDirty: make(map[common.Address]struct{}),
+ stateObjectsDestruct: make(map[common.Address]struct{}),
+ logs: make(map[common.Hash][]*types.Log),
+ preimages: make(map[common.Hash][]byte),
+ journal: newJournal(),
+ accessList: newAccessList(),
+ transientStorage: newTransientStorage(),
+ hasher: crypto.NewKeccakState(),
+ }
+ if sdb.snaps != nil {
+ if sdb.snap = sdb.snaps.Snapshot(root); sdb.snap != nil {
+ sdb.snapAccounts = make(map[common.Hash][]byte)
+ sdb.snapStorage = make(map[common.Hash]map[common.Hash][]byte)
+ }
}
return sdb, nil
}
+// StartPrefetcher initializes a new trie prefetcher to pull in nodes from the
+// state trie concurrently while the state is mutated so that when we reach the
+// commit phase, most of the needed data is already hot.
+func (s *StateDB) StartPrefetcher(namespace string) {
+ if s.prefetcher != nil {
+ s.prefetcher.close()
+ s.prefetcher = nil
+ }
+ if s.snap != nil {
+ s.prefetcher = newTriePrefetcher(s.db, s.originalRoot, namespace)
+ }
+}
+
+// StopPrefetcher terminates a running prefetcher and reports any leftover stats
+// from the gathered metrics.
+func (s *StateDB) StopPrefetcher() {
+ if s.prefetcher != nil {
+ s.prefetcher.close()
+ s.prefetcher = nil
+ }
+}
+
// setError remembers the first non-nil error it is called with.
func (s *StateDB) setError(err error) {
if s.dbErr == nil {
@@ -75,17 +184,44 @@ func (s *StateDB) setError(err error) {
}
}
+// Error returns the memorized database failure that occurred earlier.
func (s *StateDB) Error() error {
return s.dbErr
}
func (s *StateDB) AddLog(log *types.Log) {
- panic("unsupported")
+ s.journal.append(addLogChange{txhash: s.thash})
+
+ log.TxHash = s.thash
+ log.TxIndex = uint(s.txIndex)
+ log.Index = s.logSize
+ s.logs[s.thash] = append(s.logs[s.thash], log)
+ s.logSize++
+}
+
+// GetLogs returns the logs matching the specified transaction hash, and annotates
+// them with the given blockNumber and blockHash.
+func (s *StateDB) GetLogs(hash common.Hash, blockNumber uint64, blockHash common.Hash) []*types.Log {
+ logs := s.logs[hash]
+ for _, l := range logs {
+ l.BlockNumber = blockNumber
+ l.BlockHash = blockHash
+ }
+ return logs
+}
+
+func (s *StateDB) Logs() []*types.Log {
+ var logs []*types.Log
+ for _, lgs := range s.logs {
+ logs = append(logs, lgs...)
+ }
+ return logs
}
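A rough sketch of how the log plumbing fits together (the hashes and block number here are hypothetical values):

txHash, blockHash := common.Hash{0xaa}, common.Hash{0xbb}
s.SetTxContext(txHash, 0)
s.AddLog(&types.Log{Address: common.Address{0x01}}) // TxHash, TxIndex and Index are filled in here
logs := s.GetLogs(txHash, 42, blockHash)            // BlockNumber and BlockHash are stamped on retrieval
_ = logs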
// AddPreimage records a SHA3 preimage seen by the VM.
func (s *StateDB) AddPreimage(hash common.Hash, preimage []byte) {
if _, ok := s.preimages[hash]; !ok {
+ s.journal.append(addPreimageChange{hash: hash})
pi := make([]byte, len(preimage))
copy(pi, preimage)
s.preimages[hash] = pi
@@ -99,13 +235,18 @@ func (s *StateDB) Preimages() map[common.Hash][]byte {
// AddRefund adds gas to the refund counter
func (s *StateDB) AddRefund(gas uint64) {
- panic("unsupported")
+ s.journal.append(refundChange{prev: s.refund})
+ s.refund += gas
}
// SubRefund removes gas from the refund counter.
// This method will panic if the refund counter goes below zero
func (s *StateDB) SubRefund(gas uint64) {
- panic("unsupported")
+ s.journal.append(refundChange{prev: s.refund})
+ if gas > s.refund {
+ panic(fmt.Sprintf("Refund counter below zero (gas: %d > refund: %d)", gas, s.refund))
+ }
+ s.refund -= gas
}
// Exist reports whether the given account address exists in the state.
@@ -139,6 +280,11 @@ func (s *StateDB) GetNonce(addr common.Address) uint64 {
return 0
}
+// TxIndex returns the current transaction index set by SetTxContext.
+func (s *StateDB) TxIndex() int {
+ return s.txIndex
+}
+
func (s *StateDB) GetCode(addr common.Address) []byte {
stateObject := s.getStateObject(addr)
if stateObject != nil {
@@ -186,18 +332,28 @@ func (s *StateDB) GetProofByHash(addrHash common.Hash) ([][]byte, error) {
// GetStorageProof returns the Merkle proof for given storage slot.
func (s *StateDB) GetStorageProof(a common.Address, key common.Hash) ([][]byte, error) {
- var proof proofList
- trie := s.StorageTrie(a)
- if trie == nil {
- return proof, errors.New("storage trie for requested address does not exist")
+ trie, err := s.StorageTrie(a)
+ if err != nil {
+ return nil, err
}
- err := trie.Prove(crypto.Keccak256(key.Bytes()), 0, &proof)
- return proof, err
+ if trie == nil {
+ return nil, errors.New("storage trie for requested address does not exist")
+ }
+ var proof proofList
+ err = trie.Prove(crypto.Keccak256(key.Bytes()), 0, &proof)
+ if err != nil {
+ return nil, err
+ }
+ return proof, nil
}
// GetCommittedState retrieves a value from the given account's committed storage trie.
func (s *StateDB) GetCommittedState(addr common.Address, hash common.Hash) common.Hash {
- return s.GetState(addr, hash)
+ stateObject := s.getStateObject(addr)
+ if stateObject != nil {
+ return stateObject.GetCommittedState(s.db, hash)
+ }
+ return common.Hash{}
}
// Database retrieves the low level database supporting the lower level trie ops.
@@ -205,17 +361,26 @@ func (s *StateDB) Database() Database {
return s.db
}
-// StorageTrie returns the storage trie of an account.
-// The return value is a copy and is nil for non-existent accounts.
-func (s *StateDB) StorageTrie(addr common.Address) Trie {
+// StorageTrie returns the storage trie of an account. The return value is a copy
+// and is nil for non-existent accounts. An error will be returned if the storage
+// trie exists but can't be loaded correctly.
+func (s *StateDB) StorageTrie(addr common.Address) (Trie, error) {
stateObject := s.getStateObject(addr)
if stateObject == nil {
- return nil
+ return nil, nil
}
- return stateObject.getTrie(s.db)
+ cpy := stateObject.deepCopy(s)
+ if _, err := cpy.updateTrie(s.db); err != nil {
+ return nil, err
+ }
+ return cpy.getTrie(s.db)
}
func (s *StateDB) HasSuicided(addr common.Address) bool {
+ stateObject := s.getStateObject(addr)
+ if stateObject != nil {
+ return stateObject.suicided
+ }
return false
}
@@ -225,34 +390,61 @@ func (s *StateDB) HasSuicided(addr common.Address) bool {
// AddBalance adds amount to the account associated with addr.
func (s *StateDB) AddBalance(addr common.Address, amount *big.Int) {
- panic("unsupported")
+ stateObject := s.GetOrNewStateObject(addr)
+ if stateObject != nil {
+ stateObject.AddBalance(amount)
+ }
}
// SubBalance subtracts amount from the account associated with addr.
func (s *StateDB) SubBalance(addr common.Address, amount *big.Int) {
- panic("unsupported")
+ stateObject := s.GetOrNewStateObject(addr)
+ if stateObject != nil {
+ stateObject.SubBalance(amount)
+ }
}
func (s *StateDB) SetBalance(addr common.Address, amount *big.Int) {
- panic("unsupported")
+ stateObject := s.GetOrNewStateObject(addr)
+ if stateObject != nil {
+ stateObject.SetBalance(amount)
+ }
}
func (s *StateDB) SetNonce(addr common.Address, nonce uint64) {
- panic("unsupported")
+ stateObject := s.GetOrNewStateObject(addr)
+ if stateObject != nil {
+ stateObject.SetNonce(nonce)
+ }
}
func (s *StateDB) SetCode(addr common.Address, code []byte) {
- panic("unsupported")
+ stateObject := s.GetOrNewStateObject(addr)
+ if stateObject != nil {
+ stateObject.SetCode(crypto.Keccak256Hash(code), code)
+ }
}
func (s *StateDB) SetState(addr common.Address, key, value common.Hash) {
- panic("unsupported")
+ stateObject := s.GetOrNewStateObject(addr)
+ if stateObject != nil {
+ stateObject.SetState(s.db, key, value)
+ }
}
// SetStorage replaces the entire storage for the specified account with given
// storage. This function should only be used for debugging.
func (s *StateDB) SetStorage(addr common.Address, storage map[common.Hash]common.Hash) {
- panic("unsupported")
+ // SetStorage needs to wipe existing storage. We achieve this by pretending
+ // that the account self-destructed earlier in this block, by flagging
+ // it in stateObjectsDestruct. The effect of doing so is that storage lookups
+ // will not hit disk, since it is assumed that the disk data belongs
+ // to a previous incarnation of the object.
+ s.stateObjectsDestruct[addr] = struct{}{}
+ stateObject := s.GetOrNewStateObject(addr)
+ for k, v := range storage {
+ stateObject.SetState(s.db, k, v)
+ }
}
// Suicide marks the given account as suicided.
@@ -261,36 +453,147 @@ func (s *StateDB) SetStorage(addr common.Address, storage map[common.Hash]common
// The account's state object is still available until the state is committed,
// getStateObject will return a non-nil account after Suicide.
func (s *StateDB) Suicide(addr common.Address) bool {
- panic("unsupported")
- return false
+ stateObject := s.getStateObject(addr)
+ if stateObject == nil {
+ return false
+ }
+ s.journal.append(suicideChange{
+ account: &addr,
+ prev: stateObject.suicided,
+ prevbalance: new(big.Int).Set(stateObject.Balance()),
+ })
+ stateObject.markSuicided()
+ stateObject.data.Balance = new(big.Int)
+
+ return true
+}
+
+// SetTransientState sets transient storage for a given account. It
+// adds the change to the journal so that it can be rolled back
+// to its previous value if there is a revert.
+func (s *StateDB) SetTransientState(addr common.Address, key, value common.Hash) {
+ prev := s.GetTransientState(addr, key)
+ if prev == value {
+ return
+ }
+ s.journal.append(transientStorageChange{
+ account: &addr,
+ key: key,
+ prevalue: prev,
+ })
+ s.setTransientState(addr, key, value)
+}
+
+// setTransientState is a lower level setter for transient storage. It
+// is called during a revert to prevent modifications to the journal.
+func (s *StateDB) setTransientState(addr common.Address, key, value common.Hash) {
+ s.transientStorage.Set(addr, key, value)
+}
+
+// GetTransientState gets transient storage for a given account.
+func (s *StateDB) GetTransientState(addr common.Address, key common.Hash) common.Hash {
+ return s.transientStorage.Get(addr, key)
}
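Taken together with the journal, transient writes become revertible; a minimal sketch (addr, key and the hash values are hypothetical):

s.SetTransientState(addr, key, common.Hash{0x01}) // journaled
rev := s.Snapshot()
s.SetTransientState(addr, key, common.Hash{0x02})
s.RevertToSnapshot(rev)             // journal replay restores the previous value
v := s.GetTransientState(addr, key) // v == common.Hash{0x01}
_ = v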
//
// Setting, updating & deleting state object methods.
//
+// updateStateObject writes the given object to the trie.
+func (s *StateDB) updateStateObject(obj *stateObject) {
+ // Track the amount of time wasted on updating the account from the trie
+ if metrics.EnabledExpensive {
+ defer func(start time.Time) { s.AccountUpdates += time.Since(start) }(time.Now())
+ }
+ // Encode the account and update the account trie
+ addr := obj.Address()
+ if err := s.trie.TryUpdateAccount(addr, &obj.data); err != nil {
+ s.setError(fmt.Errorf("updateStateObject (%x) error: %v", addr[:], err))
+ }
+
+ // If state snapshotting is active, cache the data til commit. Note, this
+ // update mechanism is not symmetric to the deletion, because whereas it is
+ // enough to track account updates at commit time, deletions need tracking
+ // at transaction boundary level to ensure we capture state clearing.
+ if s.snap != nil {
+ s.snapAccounts[obj.addrHash] = snapshot.SlimAccountRLP(obj.data.Nonce, obj.data.Balance, obj.data.Root, obj.data.CodeHash)
+ }
+}
+
+// deleteStateObject removes the given object from the state trie.
+func (s *StateDB) deleteStateObject(obj *stateObject) {
+ // Track the amount of time wasted on deleting the account from the trie
+ if metrics.EnabledExpensive {
+ defer func(start time.Time) { s.AccountUpdates += time.Since(start) }(time.Now())
+ }
+ // Delete the account from the trie
+ addr := obj.Address()
+ if err := s.trie.TryDeleteAccount(addr); err != nil {
+ s.setError(fmt.Errorf("deleteStateObject (%x) error: %v", addr[:], err))
+ }
+}
+
// getStateObject retrieves a state object given by the address, returning nil if
-// the object is not found or was deleted in this execution context.
+// the object is not found or was deleted in this execution context. If you need
+// to differentiate between non-existent/just-deleted, use getDeletedStateObject.
func (s *StateDB) getStateObject(addr common.Address) *stateObject {
+ if obj := s.getDeletedStateObject(addr); obj != nil && !obj.deleted {
+ return obj
+ }
+ return nil
+}
+
+// getDeletedStateObject is similar to getStateObject, but instead of returning
+// nil for a deleted state object, it returns the actual object with the deleted
+// flag set. This is needed by the state journal to revert to the correct self-
+// destructed object instead of wiping all knowledge about the state object.
+func (s *StateDB) getDeletedStateObject(addr common.Address) *stateObject {
// Prefer live objects if any is available
if obj := s.stateObjects[addr]; obj != nil {
return obj
}
- // If no live objects are available, load from the database
- start := time.Now()
- var err error
- data, err := s.trie.TryGetAccount(addr.Bytes())
- if metrics.EnabledExpensive {
- s.AccountReads += time.Since(start)
- }
- if err != nil {
- s.setError(fmt.Errorf("getStateObject (%x) error: %w", addr.Bytes(), err))
- return nil
+ // If no live objects are available, attempt to use snapshots
+ var data *types.StateAccount
+ if s.snap != nil {
+ start := time.Now()
+ acc, err := s.snap.Account(crypto.HashData(s.hasher, addr.Bytes()))
+ if metrics.EnabledExpensive {
+ s.SnapshotAccountReads += time.Since(start)
+ }
+ if err == nil {
+ if acc == nil {
+ return nil
+ }
+ data = &types.StateAccount{
+ Nonce: acc.Nonce,
+ Balance: acc.Balance,
+ CodeHash: acc.CodeHash,
+ Root: common.BytesToHash(acc.Root),
+ }
+ if len(data.CodeHash) == 0 {
+ data.CodeHash = types.EmptyCodeHash.Bytes()
+ }
+ if data.Root == (common.Hash{}) {
+ data.Root = types.EmptyRootHash
+ }
+ }
}
+ // If snapshot unavailable or reading from it failed, load from the database
if data == nil {
- return nil
+ start := time.Now()
+ var err error
+ data, err = s.trie.TryGetAccount(addr)
+ if metrics.EnabledExpensive {
+ s.AccountReads += time.Since(start)
+ }
+ if err != nil {
+ s.setError(fmt.Errorf("getDeleteStateObject (%x) error: %w", addr.Bytes(), err))
+ return nil
+ }
+ if data == nil {
+ return nil
+ }
}
-
// Insert into the live set
obj := newObject(s, addr, *data)
s.setStateObject(obj)
@@ -301,69 +604,439 @@ func (s *StateDB) setStateObject(object *stateObject) {
s.stateObjects[object.Address()] = object
}
+// GetOrNewStateObject retrieves a state object or creates a new one if nil.
+func (s *StateDB) GetOrNewStateObject(addr common.Address) *stateObject {
+ stateObject := s.getStateObject(addr)
+ if stateObject == nil {
+ stateObject, _ = s.createObject(addr)
+ }
+ return stateObject
+}
+
+// createObject creates a new state object. If there is an existing account with
+// the given address, it is overwritten and returned as the second return value.
+func (s *StateDB) createObject(addr common.Address) (newobj, prev *stateObject) {
+ prev = s.getDeletedStateObject(addr) // Note, prev might have been deleted, we need that!
+
+ var prevdestruct bool
+ if prev != nil {
+ _, prevdestruct = s.stateObjectsDestruct[prev.address]
+ if !prevdestruct {
+ s.stateObjectsDestruct[prev.address] = struct{}{}
+ }
+ }
+ newobj = newObject(s, addr, types.StateAccount{})
+ if prev == nil {
+ s.journal.append(createObjectChange{account: &addr})
+ } else {
+ s.journal.append(resetObjectChange{prev: prev, prevdestruct: prevdestruct})
+ }
+ s.setStateObject(newobj)
+ if prev != nil && !prev.deleted {
+ return newobj, prev
+ }
+ return newobj, nil
+}
+
// CreateAccount explicitly creates a state object. If a state object with the address
// already exists the balance is carried over to the new account.
//
// CreateAccount is called during the EVM CREATE operation. The situation might arise that
// a contract does the following:
//
-// 1. sends funds to sha(account ++ (nonce + 1))
-// 2. tx_create(sha(account ++ nonce)) (note that this gets the address of 1)
+// 1. sends funds to sha(account ++ (nonce + 1))
+// 2. tx_create(sha(account ++ nonce)) (note that this gets the address of 1)
//
// Carrying over the balance ensures that Ether doesn't disappear.
func (s *StateDB) CreateAccount(addr common.Address) {
- panic("unsupported")
+ newObj, prev := s.createObject(addr)
+ if prev != nil {
+ newObj.setBalance(prev.data.Balance)
+ }
}
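A compact illustration of the balance carry-over (the address is hypothetical; AddBalance creates the object if needed):

s.AddBalance(addr, big.NewInt(5)) // e.g. funds sent to a not-yet-created address
s.CreateAccount(addr)             // fresh object: nonce, code and storage reset...
_ = s.GetBalance(addr)            // ...but the 5 wei balance is carried over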
func (db *StateDB) ForEachStorage(addr common.Address, cb func(key, value common.Hash) bool) error {
+ so := db.getStateObject(addr)
+ if so == nil {
+ return nil
+ }
+ tr, err := so.getTrie(db.db)
+ if err != nil {
+ return err
+ }
+ it := trie.NewIterator(tr.NodeIterator(nil))
+
+ for it.Next() {
+ key := common.BytesToHash(db.trie.GetKey(it.Key))
+ if value, dirty := so.dirtyStorage[key]; dirty {
+ if !cb(key, value) {
+ return nil
+ }
+ continue
+ }
+
+ if len(it.Value) > 0 {
+ _, content, _, err := rlp.Split(it.Value)
+ if err != nil {
+ return err
+ }
+ if !cb(key, common.BytesToHash(content)) {
+ return nil
+ }
+ }
+ }
return nil
}
+// Copy creates a deep, independent copy of the state.
+// Snapshots of the copied state cannot be applied to the copy.
+func (s *StateDB) Copy() *StateDB {
+ // Copy all the basic fields, initialize the memory ones
+ state := &StateDB{
+ db: s.db,
+ trie: s.db.CopyTrie(s.trie),
+ originalRoot: s.originalRoot,
+ stateObjects: make(map[common.Address]*stateObject, len(s.journal.dirties)),
+ stateObjectsPending: make(map[common.Address]struct{}, len(s.stateObjectsPending)),
+ stateObjectsDirty: make(map[common.Address]struct{}, len(s.journal.dirties)),
+ stateObjectsDestruct: make(map[common.Address]struct{}, len(s.stateObjectsDestruct)),
+ refund: s.refund,
+ logs: make(map[common.Hash][]*types.Log, len(s.logs)),
+ logSize: s.logSize,
+ preimages: make(map[common.Hash][]byte, len(s.preimages)),
+ journal: newJournal(),
+ hasher: crypto.NewKeccakState(),
+ }
+ // Copy the dirty states, logs, and preimages
+ for addr := range s.journal.dirties {
+ // As documented [here](https://github.com/ethereum/go-ethereum/pull/16485#issuecomment-380438527),
+ // and in the Finalise-method, there is a case where an object is in the journal but not
+ // in the stateObjects: OOG after touch on ripeMD prior to Byzantium. Thus, we need to check for
+ // nil
+ if object, exist := s.stateObjects[addr]; exist {
+ // Even though the original object is dirty, we are not copying the journal,
+ // so we need to make sure that any side-effect the journal would have caused
+ // during a commit (or similar op) is already applied to the copy.
+ state.stateObjects[addr] = object.deepCopy(state)
+
+ state.stateObjectsDirty[addr] = struct{}{} // Mark the copy dirty to force internal (code/state) commits
+ state.stateObjectsPending[addr] = struct{}{} // Mark the copy pending to force external (account) commits
+ }
+ }
+ // Above, we don't copy the actual journal. This means that if the copy
+ // is copied, the loop above will be a no-op, since the copy's journal
+ // is empty. Thus, here we iterate over stateObjects, to enable copies
+ // of copies.
+ for addr := range s.stateObjectsPending {
+ if _, exist := state.stateObjects[addr]; !exist {
+ state.stateObjects[addr] = s.stateObjects[addr].deepCopy(state)
+ }
+ state.stateObjectsPending[addr] = struct{}{}
+ }
+ for addr := range s.stateObjectsDirty {
+ if _, exist := state.stateObjects[addr]; !exist {
+ state.stateObjects[addr] = s.stateObjects[addr].deepCopy(state)
+ }
+ state.stateObjectsDirty[addr] = struct{}{}
+ }
+ // Deep copy the destruction flag.
+ for addr := range s.stateObjectsDestruct {
+ state.stateObjectsDestruct[addr] = struct{}{}
+ }
+ for hash, logs := range s.logs {
+ cpy := make([]*types.Log, len(logs))
+ for i, l := range logs {
+ cpy[i] = new(types.Log)
+ *cpy[i] = *l
+ }
+ state.logs[hash] = cpy
+ }
+ for hash, preimage := range s.preimages {
+ state.preimages[hash] = preimage
+ }
+ // Do we need to copy the access list and transient storage?
+ // In practice: No. At the start of a transaction, these two lists are empty.
+ // In practice, we only ever copy state _between_ transactions/blocks, never
+ // in the middle of a transaction. However, it doesn't cost us much to copy
+ // empty lists, so we do it anyway to not blow up if we ever decide to copy them
+ // in the middle of a transaction.
+ state.accessList = s.accessList.Copy()
+ state.transientStorage = s.transientStorage.Copy()
+
+ // If there's a prefetcher running, make an inactive copy of it that can
+ // only access data but does not actively preload (since the user will not
+ // know that they need to explicitly terminate an active copy).
+ if s.prefetcher != nil {
+ state.prefetcher = s.prefetcher.copy()
+ }
+ if s.snaps != nil {
+ // In order for the miner to be able to use and make additions
+ // to the snapshot tree, we need to copy that as well.
+ // Otherwise, any block mined by ourselves will cause gaps in the tree,
+ // and force the miner to operate trie-backed only
+ state.snaps = s.snaps
+ state.snap = s.snap
+
+ // deep copy needed
+ state.snapAccounts = make(map[common.Hash][]byte)
+ for k, v := range s.snapAccounts {
+ state.snapAccounts[k] = v
+ }
+ state.snapStorage = make(map[common.Hash]map[common.Hash][]byte)
+ for k, v := range s.snapStorage {
+ temp := make(map[common.Hash][]byte)
+ for kk, vv := range v {
+ temp[kk] = vv
+ }
+ state.snapStorage[k] = temp
+ }
+ }
+ return state
+}
+
// Snapshot returns an identifier for the current revision of the state.
func (s *StateDB) Snapshot() int {
- return 0
+ id := s.nextRevisionId
+ s.nextRevisionId++
+ s.validRevisions = append(s.validRevisions, revision{id, s.journal.length()})
+ return id
}
// RevertToSnapshot reverts all state changes made since the given revision.
func (s *StateDB) RevertToSnapshot(revid int) {
- panic("unsupported")
+ // Find the snapshot in the stack of valid snapshots.
+ idx := sort.Search(len(s.validRevisions), func(i int) bool {
+ return s.validRevisions[i].id >= revid
+ })
+ if idx == len(s.validRevisions) || s.validRevisions[idx].id != revid {
+ panic(fmt.Errorf("revision id %v cannot be reverted", revid))
+ }
+ snapshot := s.validRevisions[idx].journalIndex
+
+ // Replay the journal to undo changes and remove invalidated snapshots
+ s.journal.revert(s, snapshot)
+ s.validRevisions = s.validRevisions[:idx]
}
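As a sketch of the intended nesting discipline (addr is hypothetical; revisions must be reverted newest-first):

rev0 := s.Snapshot()
s.AddBalance(addr, big.NewInt(1))
rev1 := s.Snapshot()
s.AddBalance(addr, big.NewInt(1))
s.RevertToSnapshot(rev1) // undoes only the second AddBalance
s.RevertToSnapshot(rev0) // undoes the first; rev1 is now invalid and reverting it would panic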
// GetRefund returns the current value of the refund counter.
func (s *StateDB) GetRefund() uint64 {
- panic("unsupported")
- return 0
+ return s.refund
}
-// PrepareAccessList handles the preparatory steps for executing a state transition with
-// regards to both EIP-2929 and EIP-2930:
+// Finalise finalises the state by removing the destructed objects and clears
+// the journal as well as the refunds. Finalise, however, will not push any updates
+// into the tries just yet. Only IntermediateRoot or Commit will do that.
+func (s *StateDB) Finalise(deleteEmptyObjects bool) {
+ addressesToPrefetch := make([][]byte, 0, len(s.journal.dirties))
+ for addr := range s.journal.dirties {
+ obj, exist := s.stateObjects[addr]
+ if !exist {
+ // ripeMD is 'touched' at block 1714175, in tx 0x1237f737031e40bcde4a8b7e717b2d15e3ecadfe49bb1bbc71ee9deb09c6fcf2
+ // That tx goes out of gas, and although the notion of 'touched' does not exist there, the
+ // touch-event will still be recorded in the journal. Since ripeMD is a special snowflake,
+ // it will persist in the journal even though the journal is reverted. In this special circumstance,
+ // it may exist in `s.journal.dirties` but not in `s.stateObjects`.
+ // Thus, we can safely ignore it here
+ continue
+ }
+ if obj.suicided || (deleteEmptyObjects && obj.empty()) {
+ obj.deleted = true
+
+ // We need to maintain account deletions explicitly (will remain
+ // set indefinitely).
+ s.stateObjectsDestruct[obj.address] = struct{}{}
+
+ // If state snapshotting is active, also mark the destruction there.
+ // Note, we can't do this only at the end of a block because multiple
+ // transactions within the same block might self destruct and then
+ // resurrect an account; but the snapshotter needs both events.
+ if s.snap != nil {
+ delete(s.snapAccounts, obj.addrHash) // Clear out any previously updated account data (may be recreated via a resurrect)
+ delete(s.snapStorage, obj.addrHash) // Clear out any previously updated storage data (may be recreated via a resurrect)
+ }
+ } else {
+ obj.finalise(true) // Prefetch slots in the background
+ }
+ s.stateObjectsPending[addr] = struct{}{}
+ s.stateObjectsDirty[addr] = struct{}{}
+
+ // At this point, also ship the address off to the precacher. The precacher
+ // will start loading tries, and when the change is eventually committed,
+ // the commit-phase will be a lot faster
+ addressesToPrefetch = append(addressesToPrefetch, common.CopyBytes(addr[:])) // Copy needed for closure
+ }
+ if s.prefetcher != nil && len(addressesToPrefetch) > 0 {
+ s.prefetcher.prefetch(common.Hash{}, s.originalRoot, addressesToPrefetch)
+ }
+ // Invalidate journal because reverting across transactions is not allowed.
+ s.clearJournalAndRefund()
+}
+
+// IntermediateRoot computes the current root hash of the state trie.
+// It is called in between transactions to get the root hash that
+// goes into transaction receipts.
+func (s *StateDB) IntermediateRoot(deleteEmptyObjects bool) common.Hash {
+ // Finalise all the dirty storage states and write them into the tries
+ s.Finalise(deleteEmptyObjects)
+
+ // If there was a trie prefetcher operating, it gets aborted and irrevocably
+ // modified after we start retrieving tries. Remove it from the statedb after
+ // this round of use.
+ //
+ // This is weird pre-byzantium since the first tx runs with a prefetcher and
+ // the remainder without, but pre-byzantium even the initial prefetcher is
+ // useless, so no sleep lost.
+ prefetcher := s.prefetcher
+ if s.prefetcher != nil {
+ defer func() {
+ s.prefetcher.close()
+ s.prefetcher = nil
+ }()
+ }
+ // Although naively it makes sense to retrieve the account trie and then do
+ // the contract storage and account updates sequentially, that short circuits
+ // the account prefetcher. Instead, let's process all the storage updates
+ // first, giving the account prefetches just a few more milliseconds of time
+ // to pull useful data from disk.
+ for addr := range s.stateObjectsPending {
+ if obj := s.stateObjects[addr]; !obj.deleted {
+ obj.updateRoot(s.db)
+ }
+ }
+ // Now we're about to start to write changes to the trie. The trie is so far
+ // _untouched_. We can check with the prefetcher, if it can give us a trie
+ // which has the same root, but also has some content loaded into it.
+ if prefetcher != nil {
+ if trie := prefetcher.trie(common.Hash{}, s.originalRoot); trie != nil {
+ s.trie = trie
+ }
+ }
+ usedAddrs := make([][]byte, 0, len(s.stateObjectsPending))
+ for addr := range s.stateObjectsPending {
+ if obj := s.stateObjects[addr]; obj.deleted {
+ s.deleteStateObject(obj)
+ s.AccountDeleted += 1
+ } else {
+ s.updateStateObject(obj)
+ s.AccountUpdated += 1
+ }
+ usedAddrs = append(usedAddrs, common.CopyBytes(addr[:])) // Copy needed for closure
+ }
+ if prefetcher != nil {
+ prefetcher.used(common.Hash{}, s.originalRoot, usedAddrs)
+ }
+ if len(s.stateObjectsPending) > 0 {
+ s.stateObjectsPending = make(map[common.Address]struct{})
+ }
+ // Track the amount of time wasted on hashing the account trie
+ if metrics.EnabledExpensive {
+ defer func(start time.Time) { s.AccountHashes += time.Since(start) }(time.Now())
+ }
+ return s.trie.Hash()
+}
+
+// SetTxContext sets the current transaction hash and index which are
+// used when the EVM emits new state logs. It should be invoked before
+// transaction execution.
+func (s *StateDB) SetTxContext(thash common.Hash, ti int) {
+ s.thash = thash
+ s.txIndex = ti
+}
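In a block-processing loop, SetTxContext and IntermediateRoot pair up roughly as below (block, chainConfig and applyTx are hypothetical stand-ins for the caller's machinery):

for i, tx := range block.Transactions() {
    s.SetTxContext(tx.Hash(), i)
    applyTx(s, tx) // mutate state via the EVM
    root := s.IntermediateRoot(chainConfig.IsEIP158(block.Number()))
    _ = root       // pre-Byzantium, this root goes into the receipt
}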
+
+func (s *StateDB) clearJournalAndRefund() {
+ if len(s.journal.entries) > 0 {
+ s.journal = newJournal()
+ s.refund = 0
+ }
+ s.validRevisions = s.validRevisions[:0] // Snapshots can be created without journal entries
+}
+
+// Prepare handles the preparatory steps for executing a state transition.
+// This method must be invoked before state transition begins.
//
+// Berlin fork:
// - Add sender to access list (2929)
// - Add destination to access list (2929)
// - Add precompiles to access list (2929)
// - Add the contents of the optional tx access list (2930)
//
-// This method should only be called if Berlin/2929+2930 is applicable at the current number.
-func (s *StateDB) PrepareAccessList(sender common.Address, dst *common.Address, precompiles []common.Address, list types.AccessList) {
- panic("unsupported")
+// Potential EIPs:
+// - Reset access list (Berlin)
+// - Add coinbase to access list (EIP-3651)
+// - Reset transient storage (EIP-1153)
+func (s *StateDB) Prepare(rules params.Rules, sender, coinbase common.Address, dst *common.Address, precompiles []common.Address, list types.AccessList) {
+ if rules.IsBerlin {
+ // Clear out any leftover from previous executions
+ al := newAccessList()
+ s.accessList = al
+
+ al.AddAddress(sender)
+ if dst != nil {
+ al.AddAddress(*dst)
+ // If it's a create-tx, the destination will be added inside evm.create
+ }
+ for _, addr := range precompiles {
+ al.AddAddress(addr)
+ }
+ for _, el := range list {
+ al.AddAddress(el.Address)
+ for _, key := range el.StorageKeys {
+ al.AddSlot(el.Address, key)
+ }
+ }
+ if rules.IsShanghai { // EIP-3651: warm coinbase
+ al.AddAddress(coinbase)
+ }
+ }
+ // Reset transient storage at the beginning of transaction execution
+ s.transientStorage = newTransientStorage()
}
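A typical invocation before each transaction might look like the following sketch (header, msg and chainConfig are hypothetical caller-side values; ActivePrecompiles is go-ethereum's core/vm helper):

rules := chainConfig.Rules(header.Number, isMerge, header.Time)
s.Prepare(rules, msg.From, header.Coinbase, msg.To, vm.ActivePrecompiles(rules), msg.AccessList)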
// AddAddressToAccessList adds the given address to the access list
func (s *StateDB) AddAddressToAccessList(addr common.Address) {
- panic("unsupported")
+ if s.accessList.AddAddress(addr) {
+ s.journal.append(accessListAddAccountChange{&addr})
+ }
}
// AddSlotToAccessList adds the given (address, slot)-tuple to the access list
func (s *StateDB) AddSlotToAccessList(addr common.Address, slot common.Hash) {
- panic("unsupported")
+ addrMod, slotMod := s.accessList.AddSlot(addr, slot)
+ if addrMod {
+ // In practice, this should not happen, since there is no way to enter the
+ // scope of 'address' without having the 'address' become already added
+ // to the access list (via call-variant, create, etc).
+ // Better safe than sorry, though
+ s.journal.append(accessListAddAccountChange{&addr})
+ }
+ if slotMod {
+ s.journal.append(accessListAddSlotChange{
+ address: &addr,
+ slot: &slot,
+ })
+ }
}
// AddressInAccessList returns true if the given address is in the access list.
func (s *StateDB) AddressInAccessList(addr common.Address) bool {
- return false
+ return s.accessList.ContainsAddress(addr)
}
// SlotInAccessList returns true if the given (address, slot)-tuple is in the access list.
func (s *StateDB) SlotInAccessList(addr common.Address, slot common.Hash) (addressPresent bool, slotPresent bool) {
- return
+ return s.accessList.Contains(addr, slot)
+}
+
+// convertAccountSet converts a provided account set from address keyed to hash keyed.
+func (s *StateDB) convertAccountSet(set map[common.Address]struct{}) map[common.Hash]struct{} {
+ ret := make(map[common.Hash]struct{})
+ for addr := range set {
+ obj, exist := s.stateObjects[addr]
+ if !exist {
+ ret[crypto.Keccak256Hash(addr[:])] = struct{}{}
+ } else {
+ ret[obj.addrHash] = struct{}{}
+ }
+ }
+ return ret
}
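End to end, the write path restored by this patch composes roughly as in this sketch (in-memory backing as used in the tests below; the address is hypothetical):

db := NewDatabase(rawdb.NewMemoryDatabase())
s, _ := New(common.Hash{}, db, nil) // nil snapshot tree: reads fall back to the trie
addr := common.Address{0x01}

s.SetNonce(addr, 1)
rev := s.Snapshot()
s.Suicide(addr)
s.RevertToSnapshot(rev)          // journal replay resurrects the account
s.Finalise(true)                 // clears journal/refund, marks objects pending
root := s.IntermediateRoot(true) // hashes the pending changes into the trie
_ = root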
diff --git a/trie_by_cid/state/statedb_test.go b/trie_by_cid/state/statedb_test.go
new file mode 100644
index 0000000..60cd66a
--- /dev/null
+++ b/trie_by_cid/state/statedb_test.go
@@ -0,0 +1,594 @@
+// Copyright 2016 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package state
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "math"
+ "math/big"
+ "math/rand"
+ "reflect"
+ "strings"
+ "sync"
+ "testing"
+ "testing/quick"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/core/types"
+)
+
+// TestCopy tests that copying a StateDB object indeed makes the original and
+// the copy independent of each other. This test is a regression test against
+// https://github.com/ethereum/go-ethereum/pull/15549.
+func TestCopy(t *testing.T) {
+ // Create a random state test to copy and modify "independently"
+ orig, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil)
+
+ for i := byte(0); i < 255; i++ {
+ obj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
+ obj.AddBalance(big.NewInt(int64(i)))
+ orig.updateStateObject(obj)
+ }
+ orig.Finalise(false)
+
+ // Copy the state
+ copy := orig.Copy()
+
+ // Copy the copy state
+ ccopy := copy.Copy()
+
+ // modify all in memory
+ for i := byte(0); i < 255; i++ {
+ origObj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
+ copyObj := copy.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
+ ccopyObj := ccopy.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
+
+ origObj.AddBalance(big.NewInt(2 * int64(i)))
+ copyObj.AddBalance(big.NewInt(3 * int64(i)))
+ ccopyObj.AddBalance(big.NewInt(4 * int64(i)))
+
+ orig.updateStateObject(origObj)
+ copy.updateStateObject(copyObj)
+ ccopy.updateStateObject(ccopyObj)
+ }
+
+ // Finalise the changes on all concurrently
+ finalise := func(wg *sync.WaitGroup, db *StateDB) {
+ defer wg.Done()
+ db.Finalise(true)
+ }
+
+ var wg sync.WaitGroup
+ wg.Add(3)
+ go finalise(&wg, orig)
+ go finalise(&wg, copy)
+ go finalise(&wg, ccopy)
+ wg.Wait()
+
+ // Verify that the three states have been updated independently
+ for i := byte(0); i < 255; i++ {
+ origObj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
+ copyObj := copy.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
+ ccopyObj := ccopy.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
+
+ if want := big.NewInt(3 * int64(i)); origObj.Balance().Cmp(want) != 0 {
+ t.Errorf("orig obj %d: balance mismatch: have %v, want %v", i, origObj.Balance(), want)
+ }
+ if want := big.NewInt(4 * int64(i)); copyObj.Balance().Cmp(want) != 0 {
+ t.Errorf("copy obj %d: balance mismatch: have %v, want %v", i, copyObj.Balance(), want)
+ }
+ if want := big.NewInt(5 * int64(i)); ccopyObj.Balance().Cmp(want) != 0 {
+ t.Errorf("copy obj %d: balance mismatch: have %v, want %v", i, ccopyObj.Balance(), want)
+ }
+ }
+}
+
+func TestSnapshotRandom(t *testing.T) {
+ config := &quick.Config{MaxCount: 1000}
+ err := quick.Check((*snapshotTest).run, config)
+ if cerr, ok := err.(*quick.CheckError); ok {
+ test := cerr.In[0].(*snapshotTest)
+ t.Errorf("%v:\n%s", test.err, test)
+ } else if err != nil {
+ t.Error(err)
+ }
+}
+
+// A snapshotTest checks that reverting StateDB snapshots properly undoes all changes
+// captured by the snapshot. Instances of this test with pseudorandom content are created
+// by Generate.
+//
+// The test works as follows:
+//
+// A new state is created and all actions are applied to it. Several snapshots are taken
+// in between actions. The test then reverts each snapshot. For each snapshot the actions
+// leading up to it are replayed on a fresh, empty state. The behaviour of all public
+// accessor methods on the reverted state must match the return value of the equivalent
+// methods on the replayed state.
+type snapshotTest struct {
+ addrs []common.Address // all account addresses
+ actions []testAction // modifications to the state
+ snapshots []int // action indexes at which a snapshot is taken
+ err error // failure details are reported through this field
+}
+
+type testAction struct {
+ name string
+ fn func(testAction, *StateDB)
+ args []int64
+ noAddr bool
+}
+
+// newTestAction creates a random action that changes state.
+func newTestAction(addr common.Address, r *rand.Rand) testAction {
+ actions := []testAction{
+ {
+ name: "SetBalance",
+ fn: func(a testAction, s *StateDB) {
+ s.SetBalance(addr, big.NewInt(a.args[0]))
+ },
+ args: make([]int64, 1),
+ },
+ {
+ name: "AddBalance",
+ fn: func(a testAction, s *StateDB) {
+ s.AddBalance(addr, big.NewInt(a.args[0]))
+ },
+ args: make([]int64, 1),
+ },
+ {
+ name: "SetNonce",
+ fn: func(a testAction, s *StateDB) {
+ s.SetNonce(addr, uint64(a.args[0]))
+ },
+ args: make([]int64, 1),
+ },
+ {
+ name: "SetState",
+ fn: func(a testAction, s *StateDB) {
+ var key, val common.Hash
+ binary.BigEndian.PutUint16(key[:], uint16(a.args[0]))
+ binary.BigEndian.PutUint16(val[:], uint16(a.args[1]))
+ s.SetState(addr, key, val)
+ },
+ args: make([]int64, 2),
+ },
+ {
+ name: "SetCode",
+ fn: func(a testAction, s *StateDB) {
+ code := make([]byte, 16)
+ binary.BigEndian.PutUint64(code, uint64(a.args[0]))
+ binary.BigEndian.PutUint64(code[8:], uint64(a.args[1]))
+ s.SetCode(addr, code)
+ },
+ args: make([]int64, 2),
+ },
+ {
+ name: "CreateAccount",
+ fn: func(a testAction, s *StateDB) {
+ s.CreateAccount(addr)
+ },
+ },
+ {
+ name: "Suicide",
+ fn: func(a testAction, s *StateDB) {
+ s.Suicide(addr)
+ },
+ },
+ {
+ name: "AddRefund",
+ fn: func(a testAction, s *StateDB) {
+ s.AddRefund(uint64(a.args[0]))
+ },
+ args: make([]int64, 1),
+ noAddr: true,
+ },
+ {
+ name: "AddLog",
+ fn: func(a testAction, s *StateDB) {
+ data := make([]byte, 2)
+ binary.BigEndian.PutUint16(data, uint16(a.args[0]))
+ s.AddLog(&types.Log{Address: addr, Data: data})
+ },
+ args: make([]int64, 1),
+ },
+ {
+ name: "AddPreimage",
+ fn: func(a testAction, s *StateDB) {
+ preimage := []byte{1}
+ hash := common.BytesToHash(preimage)
+ s.AddPreimage(hash, preimage)
+ },
+ args: make([]int64, 1),
+ },
+ {
+ name: "AddAddressToAccessList",
+ fn: func(a testAction, s *StateDB) {
+ s.AddAddressToAccessList(addr)
+ },
+ },
+ {
+ name: "AddSlotToAccessList",
+ fn: func(a testAction, s *StateDB) {
+ s.AddSlotToAccessList(addr,
+ common.Hash{byte(a.args[0])})
+ },
+ args: make([]int64, 1),
+ },
+ {
+ name: "SetTransientState",
+ fn: func(a testAction, s *StateDB) {
+ var key, val common.Hash
+ binary.BigEndian.PutUint16(key[:], uint16(a.args[0]))
+ binary.BigEndian.PutUint16(val[:], uint16(a.args[1]))
+ s.SetTransientState(addr, key, val)
+ },
+ args: make([]int64, 2),
+ },
+ }
+ action := actions[r.Intn(len(actions))]
+ var nameargs []string
+ if !action.noAddr {
+ nameargs = append(nameargs, addr.Hex())
+ }
+ for i := range action.args {
+ action.args[i] = rand.Int63n(100)
+ nameargs = append(nameargs, fmt.Sprint(action.args[i]))
+ }
+ action.name += strings.Join(nameargs, ", ")
+ return action
+}
+
+// Generate returns a new snapshot test of the given size. All randomness is
+// derived from r.
+func (*snapshotTest) Generate(r *rand.Rand, size int) reflect.Value {
+ // Generate random actions.
+ addrs := make([]common.Address, 50)
+ for i := range addrs {
+ addrs[i][0] = byte(i)
+ }
+ actions := make([]testAction, size)
+ for i := range actions {
+ addr := addrs[r.Intn(len(addrs))]
+ actions[i] = newTestAction(addr, r)
+ }
+ // Generate snapshot indexes.
+ nsnapshots := int(math.Sqrt(float64(size)))
+ if size > 0 && nsnapshots == 0 {
+ nsnapshots = 1
+ }
+ snapshots := make([]int, nsnapshots)
+ snaplen := len(actions) / nsnapshots
+ for i := range snapshots {
+ // Try to place the snapshots some number of actions apart from each other.
+ snapshots[i] = (i * snaplen) + r.Intn(snaplen)
+ }
+ return reflect.ValueOf(&snapshotTest{addrs, actions, snapshots, nil})
+}
+
+func (test *snapshotTest) String() string {
+ out := new(bytes.Buffer)
+ sindex := 0
+ for i, action := range test.actions {
+ if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
+ fmt.Fprintf(out, "---- snapshot %d ----\n", sindex)
+ sindex++
+ }
+ fmt.Fprintf(out, "%4d: %s\n", i, action.name)
+ }
+ return out.String()
+}
+
+func (test *snapshotTest) run() bool {
+ // Run all actions and create snapshots.
+ var (
+ state, _ = New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil)
+ snapshotRevs = make([]int, len(test.snapshots))
+ sindex = 0
+ )
+ for i, action := range test.actions {
+ if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
+ snapshotRevs[sindex] = state.Snapshot()
+ sindex++
+ }
+ action.fn(action, state)
+ }
+ // Revert all snapshots in reverse order. Each revert must yield a state
+ // that is equivalent to a fresh state with all actions up to the snapshot applied.
+ for sindex--; sindex >= 0; sindex-- {
+ checkstate, _ := New(common.Hash{}, state.Database(), nil)
+ for _, action := range test.actions[:test.snapshots[sindex]] {
+ action.fn(action, checkstate)
+ }
+ state.RevertToSnapshot(snapshotRevs[sindex])
+ if err := test.checkEqual(state, checkstate); err != nil {
+ test.err = fmt.Errorf("state mismatch after revert to snapshot %d\n%v", sindex, err)
+ return false
+ }
+ }
+ return true
+}
+
+// checkEqual checks that methods of state and checkstate return the same values.
+func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
+ for _, addr := range test.addrs {
+ var err error
+ checkeq := func(op string, a, b interface{}) bool {
+ if err == nil && !reflect.DeepEqual(a, b) {
+ err = fmt.Errorf("got %s(%s) == %v, want %v", op, addr.Hex(), a, b)
+ return false
+ }
+ return true
+ }
+ // Check basic accessor methods.
+ checkeq("Exist", state.Exist(addr), checkstate.Exist(addr))
+ checkeq("HasSuicided", state.HasSuicided(addr), checkstate.HasSuicided(addr))
+ checkeq("GetBalance", state.GetBalance(addr), checkstate.GetBalance(addr))
+ checkeq("GetNonce", state.GetNonce(addr), checkstate.GetNonce(addr))
+ checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr))
+ checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr))
+ checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr))
+ // Check storage.
+ if obj := state.getStateObject(addr); obj != nil {
+ state.ForEachStorage(addr, func(key, value common.Hash) bool {
+ return checkeq("GetState("+key.Hex()+")", checkstate.GetState(addr, key), value)
+ })
+ checkstate.ForEachStorage(addr, func(key, value common.Hash) bool {
+ return checkeq("GetState("+key.Hex()+")", checkstate.GetState(addr, key), value)
+ })
+ }
+ if err != nil {
+ return err
+ }
+ }
+
+ if state.GetRefund() != checkstate.GetRefund() {
+ return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d",
+ state.GetRefund(), checkstate.GetRefund())
+ }
+ if !reflect.DeepEqual(state.GetLogs(common.Hash{}, 0, common.Hash{}), checkstate.GetLogs(common.Hash{}, 0, common.Hash{})) {
+ return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v",
+ state.GetLogs(common.Hash{}, 0, common.Hash{}), checkstate.GetLogs(common.Hash{}, 0, common.Hash{}))
+ }
+ return nil
+}
+
+// TestCopyOfCopy tests that modified objects are carried over to the copy, and the copy of the copy.
+// See https://github.com/ethereum/go-ethereum/pull/15225#issuecomment-380191512
+func TestCopyOfCopy(t *testing.T) {
+ state, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil)
+ addr := common.HexToAddress("aaaa")
+ state.SetBalance(addr, big.NewInt(42))
+
+ if got := state.Copy().GetBalance(addr).Uint64(); got != 42 {
+ t.Fatalf("1st copy fail, expected 42, got %v", got)
+ }
+ if got := state.Copy().Copy().GetBalance(addr).Uint64(); got != 42 {
+ t.Fatalf("2nd copy fail, expected 42, got %v", got)
+ }
+}
+
+func TestStateDBAccessList(t *testing.T) {
+ // Some helpers
+ addr := func(a string) common.Address {
+ return common.HexToAddress(a)
+ }
+ slot := func(a string) common.Hash {
+ return common.HexToHash(a)
+ }
+
+ memDb := rawdb.NewMemoryDatabase()
+ db := NewDatabase(memDb)
+ state, _ := New(common.Hash{}, db, nil)
+ state.accessList = newAccessList()
+
+ verifyAddrs := func(astrings ...string) {
+ t.Helper()
+ // convert to common.Address form
+ var addresses []common.Address
+ var addressMap = make(map[common.Address]struct{})
+ for _, astring := range astrings {
+ address := addr(astring)
+ addresses = append(addresses, address)
+ addressMap[address] = struct{}{}
+ }
+ // Check that the given addresses are in the access list
+ for _, address := range addresses {
+ if !state.AddressInAccessList(address) {
+ t.Fatalf("expected %x to be in access list", address)
+ }
+ }
+ // Check that only the expected addresses are present in the access list
+ for address := range state.accessList.addresses {
+ if _, exist := addressMap[address]; !exist {
+ t.Fatalf("extra address %x in access list", address)
+ }
+ }
+ }
+ verifySlots := func(addrString string, slotStrings ...string) {
+ if !state.AddressInAccessList(addr(addrString)) {
+ t.Fatalf("scope missing address/slots %v", addrString)
+ }
+ var address = addr(addrString)
+ // convert to common.Hash form
+ var slots []common.Hash
+ var slotMap = make(map[common.Hash]struct{})
+ for _, slotString := range slotStrings {
+ s := slot(slotString)
+ slots = append(slots, s)
+ slotMap[s] = struct{}{}
+ }
+ // Check that the expected items are in the access list
+ for i, s := range slots {
+ if _, slotPresent := state.SlotInAccessList(address, s); !slotPresent {
+ t.Fatalf("input %d: scope missing slot %v (address %v)", i, s, addrString)
+ }
+ }
+ // Check that no extra elements are in the access list
+ index := state.accessList.addresses[address]
+ if index >= 0 {
+ stateSlots := state.accessList.slots[index]
+ for s := range stateSlots {
+ if _, slotPresent := slotMap[s]; !slotPresent {
+ t.Fatalf("scope has extra slot %v (address %v)", s, addrString)
+ }
+ }
+ }
+ }
+
+ state.AddAddressToAccessList(addr("aa")) // 1
+ state.AddSlotToAccessList(addr("bb"), slot("01")) // 2,3
+ state.AddSlotToAccessList(addr("bb"), slot("02")) // 4
+ verifyAddrs("aa", "bb")
+ verifySlots("bb", "01", "02")
+
+ // Make a copy
+ stateCopy1 := state.Copy()
+ if exp, got := 4, state.journal.length(); exp != got {
+ t.Fatalf("journal length mismatch: have %d, want %d", got, exp)
+ }
+
+ // same again, should cause no journal entries
+ state.AddSlotToAccessList(addr("bb"), slot("01"))
+ state.AddSlotToAccessList(addr("bb"), slot("02"))
+ state.AddAddressToAccessList(addr("aa"))
+ if exp, got := 4, state.journal.length(); exp != got {
+ t.Fatalf("journal length mismatch: have %d, want %d", got, exp)
+ }
+ // some new ones
+ state.AddSlotToAccessList(addr("bb"), slot("03")) // 5
+ state.AddSlotToAccessList(addr("aa"), slot("01")) // 6
+ state.AddSlotToAccessList(addr("cc"), slot("01")) // 7,8
+ state.AddAddressToAccessList(addr("cc"))
+ if exp, got := 8, state.journal.length(); exp != got {
+ t.Fatalf("journal length mismatch: have %d, want %d", got, exp)
+ }
+
+ verifyAddrs("aa", "bb", "cc")
+ verifySlots("aa", "01")
+ verifySlots("bb", "01", "02", "03")
+ verifySlots("cc", "01")
+
+ // now start rolling back changes
+ state.journal.revert(state, 7)
+ if _, ok := state.SlotInAccessList(addr("cc"), slot("01")); ok {
+ t.Fatalf("slot present, expected missing")
+ }
+ verifyAddrs("aa", "bb", "cc")
+ verifySlots("aa", "01")
+ verifySlots("bb", "01", "02", "03")
+
+ state.journal.revert(state, 6)
+ if state.AddressInAccessList(addr("cc")) {
+ t.Fatalf("addr present, expected missing")
+ }
+ verifyAddrs("aa", "bb")
+ verifySlots("aa", "01")
+ verifySlots("bb", "01", "02", "03")
+
+ state.journal.revert(state, 5)
+ if _, ok := state.SlotInAccessList(addr("aa"), slot("01")); ok {
+ t.Fatalf("slot present, expected missing")
+ }
+ verifyAddrs("aa", "bb")
+ verifySlots("bb", "01", "02", "03")
+
+ state.journal.revert(state, 4)
+ if _, ok := state.SlotInAccessList(addr("bb"), slot("03")); ok {
+ t.Fatalf("slot present, expected missing")
+ }
+ verifyAddrs("aa", "bb")
+ verifySlots("bb", "01", "02")
+
+ state.journal.revert(state, 3)
+ if _, ok := state.SlotInAccessList(addr("bb"), slot("02")); ok {
+ t.Fatalf("slot present, expected missing")
+ }
+ verifyAddrs("aa", "bb")
+ verifySlots("bb", "01")
+
+ state.journal.revert(state, 2)
+ if _, ok := state.SlotInAccessList(addr("bb"), slot("01")); ok {
+ t.Fatalf("slot present, expected missing")
+ }
+ verifyAddrs("aa", "bb")
+
+ state.journal.revert(state, 1)
+ if state.AddressInAccessList(addr("bb")) {
+ t.Fatalf("addr present, expected missing")
+ }
+ verifyAddrs("aa")
+
+ state.journal.revert(state, 0)
+ if state.AddressInAccessList(addr("aa")) {
+ t.Fatalf("addr present, expected missing")
+ }
+ if got, exp := len(state.accessList.addresses), 0; got != exp {
+ t.Fatalf("expected empty, got %d", got)
+ }
+ if got, exp := len(state.accessList.slots), 0; got != exp {
+ t.Fatalf("expected empty, got %d", got)
+ }
+ // Check the copy taken earlier: it must still hold the pre-revert access list
+ state = stateCopy1
+ verifyAddrs("aa", "bb")
+ verifySlots("bb", "01", "02")
+ if got, exp := len(state.accessList.addresses), 2; got != exp {
+ t.Fatalf("expected empty, got %d", got)
+ }
+ if got, exp := len(state.accessList.slots), 1; got != exp {
+ t.Fatalf("expected empty, got %d", got)
+ }
+}
+
+func TestStateDBTransientStorage(t *testing.T) {
+ memDb := rawdb.NewMemoryDatabase()
+ db := NewDatabase(memDb)
+ state, _ := New(common.Hash{}, db, nil)
+
+ key := common.Hash{0x01}
+ value := common.Hash{0x02}
+ addr := common.Address{}
+
+ state.SetTransientState(addr, key, value)
+ if exp, got := 1, state.journal.length(); exp != got {
+ t.Fatalf("journal length mismatch: have %d, want %d", got, exp)
+ }
+ // the retrieved value should equal what was set
+ if got := state.GetTransientState(addr, key); got != value {
+ t.Fatalf("transient storage mismatch: have %x, want %x", got, value)
+ }
+
+ // revert the transient state being set and then check that the
+ // value is now the empty hash
+ state.journal.revert(state, 0)
+ if got, exp := state.GetTransientState(addr, key), (common.Hash{}); exp != got {
+ t.Fatalf("transient storage mismatch: have %x, want %x", got, exp)
+ }
+
+ // set transient state and then copy the statedb and ensure that
+ // the transient state is copied
+ state.SetTransientState(addr, key, value)
+ cpy := state.Copy()
+ if got := cpy.GetTransientState(addr, key); got != value {
+ t.Fatalf("transient storage mismatch: have %x, want %x", got, value)
+ }
+}
diff --git a/trie_by_cid/state/transient_storage.go b/trie_by_cid/state/transient_storage.go
new file mode 100644
index 0000000..66e563e
--- /dev/null
+++ b/trie_by_cid/state/transient_storage.go
@@ -0,0 +1,55 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package state
+
+import (
+ "github.com/ethereum/go-ethereum/common"
+)
+
+// transientStorage is a representation of EIP-1153 "Transient Storage".
+type transientStorage map[common.Address]Storage
+
+// newTransientStorage creates a new instance of a transientStorage.
+func newTransientStorage() transientStorage {
+ return make(transientStorage)
+}
+
+// Set sets the transient-storage `value` for `key` at the given `addr`.
+func (t transientStorage) Set(addr common.Address, key, value common.Hash) {
+ if _, ok := t[addr]; !ok {
+ t[addr] = make(Storage)
+ }
+ t[addr][key] = value
+}
+
+// Get gets the transient storage for `key` at the given `addr`.
+func (t transientStorage) Get(addr common.Address, key common.Hash) common.Hash {
+ val, ok := t[addr]
+ if !ok {
+ return common.Hash{}
+ }
+ return val[key]
+}
+
+// Copy does a deep copy of the transientStorage
+func (t transientStorage) Copy() transientStorage {
+ storage := make(transientStorage)
+ for key, value := range t {
+ storage[key] = value.Copy()
+ }
+ return storage
+}
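The deep copy matters because the inner Storage maps would otherwise be shared between the original and the copy; a quick sketch (addr, key and the hash values are hypothetical):

addr, key := common.Address{0x01}, common.Hash{0x02}
a := newTransientStorage()
a.Set(addr, key, common.Hash{0xaa})
b := a.Copy()                       // inner Storage maps are cloned, not aliased
a.Set(addr, key, common.Hash{0xbb}) // mutate only the original
_ = b.Get(addr, key)                // still {0xaa}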
diff --git a/trie_by_cid/state/trie_prefetcher.go b/trie_by_cid/state/trie_prefetcher.go
new file mode 100644
index 0000000..5dd1b5b
--- /dev/null
+++ b/trie_by_cid/state/trie_prefetcher.go
@@ -0,0 +1,354 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package state
+
+import (
+ "sync"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/metrics"
+ log "github.com/sirupsen/logrus"
+)
+
+var (
+ // triePrefetchMetricsPrefix is the prefix under which to publish the metrics.
+ triePrefetchMetricsPrefix = "trie/prefetch/"
+)
+
+// triePrefetcher is an active prefetcher, which receives accounts or storage
+// items and does trie-loading of them. The goal is to get as much useful content
+// into the caches as possible.
+//
+// Note, the prefetcher's API is not thread safe.
+type triePrefetcher struct {
+ db Database // Database to fetch trie nodes through
+ root common.Hash // Root hash of the account trie for metrics
+ fetches map[string]Trie // Partially or fully fetched tries
+ fetchers map[string]*subfetcher // Subfetchers for each trie
+
+ deliveryMissMeter metrics.Meter
+ accountLoadMeter metrics.Meter
+ accountDupMeter metrics.Meter
+ accountSkipMeter metrics.Meter
+ accountWasteMeter metrics.Meter
+ storageLoadMeter metrics.Meter
+ storageDupMeter metrics.Meter
+ storageSkipMeter metrics.Meter
+ storageWasteMeter metrics.Meter
+}
+
+func newTriePrefetcher(db Database, root common.Hash, namespace string) *triePrefetcher {
+ prefix := triePrefetchMetricsPrefix + namespace
+ p := &triePrefetcher{
+ db: db,
+ root: root,
+ fetchers: make(map[string]*subfetcher), // Active prefetchers use the fetchers map
+
+ deliveryMissMeter: metrics.GetOrRegisterMeter(prefix+"/deliverymiss", nil),
+ accountLoadMeter: metrics.GetOrRegisterMeter(prefix+"/account/load", nil),
+ accountDupMeter: metrics.GetOrRegisterMeter(prefix+"/account/dup", nil),
+ accountSkipMeter: metrics.GetOrRegisterMeter(prefix+"/account/skip", nil),
+ accountWasteMeter: metrics.GetOrRegisterMeter(prefix+"/account/waste", nil),
+ storageLoadMeter: metrics.GetOrRegisterMeter(prefix+"/storage/load", nil),
+ storageDupMeter: metrics.GetOrRegisterMeter(prefix+"/storage/dup", nil),
+ storageSkipMeter: metrics.GetOrRegisterMeter(prefix+"/storage/skip", nil),
+ storageWasteMeter: metrics.GetOrRegisterMeter(prefix+"/storage/waste", nil),
+ }
+ return p
+}
+
+// close iterates over all the subfetchers, aborts any that were left spinning
+// and reports the stats to the metrics subsystem.
+func (p *triePrefetcher) close() {
+ for _, fetcher := range p.fetchers {
+ fetcher.abort() // safe to do multiple times
+
+ if metrics.Enabled {
+ if fetcher.root == p.root {
+ p.accountLoadMeter.Mark(int64(len(fetcher.seen)))
+ p.accountDupMeter.Mark(int64(fetcher.dups))
+ p.accountSkipMeter.Mark(int64(len(fetcher.tasks)))
+
+ for _, key := range fetcher.used {
+ delete(fetcher.seen, string(key))
+ }
+ p.accountWasteMeter.Mark(int64(len(fetcher.seen)))
+ } else {
+ p.storageLoadMeter.Mark(int64(len(fetcher.seen)))
+ p.storageDupMeter.Mark(int64(fetcher.dups))
+ p.storageSkipMeter.Mark(int64(len(fetcher.tasks)))
+
+ for _, key := range fetcher.used {
+ delete(fetcher.seen, string(key))
+ }
+ p.storageWasteMeter.Mark(int64(len(fetcher.seen)))
+ }
+ }
+ }
+ // Clear out all fetchers (will crash on a second call, deliberate)
+ p.fetchers = nil
+}
+
+// copy creates a deep-but-inactive copy of the trie prefetcher. Any trie data
+// already loaded will be copied over, but no goroutines will be started. This
+// is mostly used in the miner which creates a copy of its actively mutated
+// state to be sealed while it may further mutate the state.
+func (p *triePrefetcher) copy() *triePrefetcher {
+ copy := &triePrefetcher{
+ db: p.db,
+ root: p.root,
+ fetches: make(map[string]Trie), // Active prefetchers use the fetches map
+
+ deliveryMissMeter: p.deliveryMissMeter,
+ accountLoadMeter: p.accountLoadMeter,
+ accountDupMeter: p.accountDupMeter,
+ accountSkipMeter: p.accountSkipMeter,
+ accountWasteMeter: p.accountWasteMeter,
+ storageLoadMeter: p.storageLoadMeter,
+ storageDupMeter: p.storageDupMeter,
+ storageSkipMeter: p.storageSkipMeter,
+ storageWasteMeter: p.storageWasteMeter,
+ }
+ // If the prefetcher is already a copy, duplicate the data
+ if p.fetches != nil {
+ for root, fetch := range p.fetches {
+ if fetch == nil {
+ continue
+ }
+ copy.fetches[root] = p.db.CopyTrie(fetch)
+ }
+ return copy
+ }
+ // Otherwise we're copying an active fetcher, retrieve the current states
+ for id, fetcher := range p.fetchers {
+ copy.fetches[id] = fetcher.peek()
+ }
+ return copy
+}
+
+// prefetch schedules a batch of trie items to prefetch.
+func (p *triePrefetcher) prefetch(owner common.Hash, root common.Hash, keys [][]byte) {
+ // If the prefetcher is an inactive one, bail out
+ if p.fetches != nil {
+ return
+ }
+ // Active fetcher, schedule the retrievals
+ id := p.trieID(owner, root)
+ fetcher := p.fetchers[id]
+ if fetcher == nil {
+ fetcher = newSubfetcher(p.db, p.root, owner, root)
+ p.fetchers[id] = fetcher
+ }
+ fetcher.schedule(keys)
+}
+
+// trie returns the trie matching the root hash, or nil if the prefetcher doesn't
+// have it.
+func (p *triePrefetcher) trie(owner common.Hash, root common.Hash) Trie {
+ // If the prefetcher is inactive, return from existing deep copies
+ id := p.trieID(owner, root)
+ if p.fetches != nil {
+ trie := p.fetches[id]
+ if trie == nil {
+ p.deliveryMissMeter.Mark(1)
+ return nil
+ }
+ return p.db.CopyTrie(trie)
+ }
+ // Otherwise the prefetcher is active, bail if no trie was prefetched for this root
+ fetcher := p.fetchers[id]
+ if fetcher == nil {
+ p.deliveryMissMeter.Mark(1)
+ return nil
+ }
+ // Interrupt the prefetcher if it's by any chance still running and return
+ // a copy of any pre-loaded trie.
+ fetcher.abort() // safe to do multiple times
+
+ trie := fetcher.peek()
+ if trie == nil {
+ p.deliveryMissMeter.Mark(1)
+ return nil
+ }
+ return trie
+}
+
+// used marks a batch of state items used to allow creating statistics as to
+// how useful or wasteful the prefetcher is.
+func (p *triePrefetcher) used(owner common.Hash, root common.Hash, used [][]byte) {
+ if fetcher := p.fetchers[p.trieID(owner, root)]; fetcher != nil {
+ fetcher.used = used
+ }
+}
+
+// trieID returns a unique trie identifier consisting of the trie owner and root hash.
+func (p *triePrefetcher) trieID(owner common.Hash, root common.Hash) string {
+ return string(append(owner.Bytes(), root.Bytes()...))
+}
+
+// subfetcher is a trie fetcher goroutine responsible for pulling entries for a
+// single trie. It is spawned when a new root is encountered and lives until the
+// main prefetcher is paused and either all requested items have been processed
+// or the trie being worked on is retrieved from the prefetcher.
+type subfetcher struct {
+ db Database // Database to load trie nodes through
+ state common.Hash // Root hash of the state to prefetch
+ owner common.Hash // Owner of the trie, usually account hash
+ root common.Hash // Root hash of the trie to prefetch
+ trie Trie // Trie being populated with nodes
+
+ tasks [][]byte // Items queued up for retrieval
+ lock sync.Mutex // Lock protecting the task queue
+
+ wake chan struct{} // Wake channel if a new task is scheduled
+ stop chan struct{} // Channel to interrupt processing
+ term chan struct{} // Channel to signal interruption
+ copy chan chan Trie // Channel to request a copy of the current trie
+
+ seen map[string]struct{} // Tracks the entries already loaded
+ dups int // Number of duplicate preload tasks
+ used [][]byte // Tracks the entries used in the end
+}
+
+// newSubfetcher creates a goroutine to prefetch state items belonging to a
+// particular root hash.
+func newSubfetcher(db Database, state common.Hash, owner common.Hash, root common.Hash) *subfetcher {
+ sf := &subfetcher{
+ db: db,
+ state: state,
+ owner: owner,
+ root: root,
+ wake: make(chan struct{}, 1),
+ stop: make(chan struct{}),
+ term: make(chan struct{}),
+ copy: make(chan chan Trie),
+ seen: make(map[string]struct{}),
+ }
+ go sf.loop()
+ return sf
+}
+
+// schedule adds a batch of trie keys to the queue to prefetch.
+func (sf *subfetcher) schedule(keys [][]byte) {
+ // Append the tasks to the current queue
+ sf.lock.Lock()
+ sf.tasks = append(sf.tasks, keys...)
+ sf.lock.Unlock()
+
+ // Notify the prefetcher, it's fine if it's already terminated
+ select {
+ case sf.wake <- struct{}{}:
+ default:
+ }
+}
+
+// peek tries to retrieve a deep copy of the fetcher's trie in whatever form it
+// is currently.
+func (sf *subfetcher) peek() Trie {
+ ch := make(chan Trie)
+ select {
+ case sf.copy <- ch:
+ // Subfetcher still alive, return copy from it
+ return <-ch
+
+ case <-sf.term:
+ // Subfetcher already terminated, return a copy directly
+ if sf.trie == nil {
+ return nil
+ }
+ return sf.db.CopyTrie(sf.trie)
+ }
+}
+
+// abort interrupts the subfetcher immediately. It is safe to call abort multiple
+// times but it is not thread safe.
+func (sf *subfetcher) abort() {
+ select {
+ case <-sf.stop:
+ default:
+ close(sf.stop)
+ }
+ <-sf.term
+}
+
+// loop waits for new tasks to be scheduled and keeps loading them until it runs
+// out of tasks or its underlying trie is retrieved for committing.
+func (sf *subfetcher) loop() {
+ // No matter how the loop stops, signal anyone waiting that it's terminated
+ defer close(sf.term)
+
+ // Start by opening the trie and stop processing if it fails
+ if sf.owner == (common.Hash{}) {
+ trie, err := sf.db.OpenTrie(sf.root)
+ if err != nil {
+ log.Warn("Trie prefetcher failed opening trie", "root", sf.root, "err", err)
+ return
+ }
+ sf.trie = trie
+ } else {
+ trie, err := sf.db.OpenStorageTrie(sf.state, sf.owner, sf.root)
+ if err != nil {
+ log.Warn("Trie prefetcher failed opening trie", "root", sf.root, "err", err)
+ return
+ }
+ sf.trie = trie
+ }
+ // Trie opened successfully, keep prefetching items
+ for {
+ select {
+ case <-sf.wake:
+ // Subfetcher was woken up, retrieve any tasks to avoid spinning the lock
+ sf.lock.Lock()
+ tasks := sf.tasks
+ sf.tasks = nil
+ sf.lock.Unlock()
+
+ // Prefetch any tasks until the loop is interrupted
+ for i, task := range tasks {
+ select {
+ case <-sf.stop:
+ // If termination is requested, add any leftover back and return
+ sf.lock.Lock()
+ sf.tasks = append(sf.tasks, tasks[i:]...)
+ sf.lock.Unlock()
+ return
+
+ case ch := <-sf.copy:
+ // Somebody wants a copy of the current trie, grant them
+ ch <- sf.db.CopyTrie(sf.trie)
+
+ default:
+ // No termination request yet, prefetch the next entry
+ if _, ok := sf.seen[string(task)]; ok {
+ sf.dups++
+ } else {
+ sf.trie.TryGet(task)
+ sf.seen[string(task)] = struct{}{}
+ }
+ }
+ }
+
+ case ch := <-sf.copy:
+ // Somebody wants a copy of the current trie, grant them
+ ch <- sf.db.CopyTrie(sf.trie)
+
+ case <-sf.stop:
+ // Termination is requested, abort and leave remaining tasks
+ return
+ }
+ }
+}
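
The fetches/fetchers duality above is easiest to see end to end. Below is a hedged sketch of the intended lifecycle (prefetcherLifecycleSketch is a hypothetical helper; Database and Trie are this package's interfaces, and the test file that follows exercises the same flow):

func prefetcherLifecycleSketch(db Database, root common.Hash, keys [][]byte) Trie {
	p := newTriePrefetcher(db, root, "example")

	// Active phase: scheduling spawns/feeds a subfetcher goroutine per trie.
	p.prefetch(common.Hash{}, root, keys)

	// A copy is inactive: it snapshots whatever was loaded at copy time,
	// starts no goroutines, and silently ignores further prefetch calls.
	inactive := p.copy()
	inactive.prefetch(common.Hash{}, root, keys) // no-op by design

	// trie() aborts the subfetcher if needed and returns the loaded trie,
	// or nil (a recorded delivery miss) if nothing was prefetched for it.
	t := p.trie(common.Hash{}, root)

	p.close() // reports metrics; the prefetcher must not be reused
	return t
}
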
diff --git a/trie_by_cid/state/trie_prefetcher_test.go b/trie_by_cid/state/trie_prefetcher_test.go
new file mode 100644
index 0000000..cb0b67d
--- /dev/null
+++ b/trie_by_cid/state/trie_prefetcher_test.go
@@ -0,0 +1,110 @@
+// Copyright 2021 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package state
+
+import (
+ "math/big"
+ "testing"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
+)
+
+func filledStateDB() *StateDB {
+ state, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil)
+
+ // Create an account and populate its balance, code and storage
+ addr := common.HexToAddress("0xaffeaffeaffeaffeaffeaffeaffeaffeaffeaffe")
+ skey := common.HexToHash("aaa")
+ sval := common.HexToHash("bbb")
+
+ state.SetBalance(addr, big.NewInt(42)) // Change the account trie
+ state.SetCode(addr, []byte("hello")) // Change an external metadata
+ state.SetState(addr, skey, sval) // Change the storage trie
+ for i := 0; i < 100; i++ {
+ sk := common.BigToHash(big.NewInt(int64(i)))
+ state.SetState(addr, sk, sk) // Change the storage trie
+ }
+ return state
+}
+
+func TestCopyAndClose(t *testing.T) {
+ db := filledStateDB()
+ prefetcher := newTriePrefetcher(db.db, db.originalRoot, "")
+ skey := common.HexToHash("aaa")
+ prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()})
+ prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()})
+ time.Sleep(1 * time.Second)
+ a := prefetcher.trie(common.Hash{}, db.originalRoot)
+ prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()})
+ b := prefetcher.trie(common.Hash{}, db.originalRoot)
+ cpy := prefetcher.copy()
+ cpy.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()})
+ cpy.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()})
+ c := cpy.trie(common.Hash{}, db.originalRoot)
+ prefetcher.close()
+ cpy2 := cpy.copy()
+ cpy2.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()})
+ d := cpy2.trie(common.Hash{}, db.originalRoot)
+ cpy.close()
+ cpy2.close()
+ if a.Hash() != b.Hash() || a.Hash() != c.Hash() || a.Hash() != d.Hash() {
+ t.Fatalf("Invalid trie, hashes should be equal: %v %v %v %v", a.Hash(), b.Hash(), c.Hash(), d.Hash())
+ }
+}
+
+func TestUseAfterClose(t *testing.T) {
+ db := filledStateDB()
+ prefetcher := newTriePrefetcher(db.db, db.originalRoot, "")
+ skey := common.HexToHash("aaa")
+ prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()})
+ a := prefetcher.trie(common.Hash{}, db.originalRoot)
+ prefetcher.close()
+ b := prefetcher.trie(common.Hash{}, db.originalRoot)
+ if a == nil {
+ t.Fatal("Prefetching before close should not return nil")
+ }
+ if b != nil {
+ t.Fatal("Trie after close should return nil")
+ }
+}
+
+func TestCopyClose(t *testing.T) {
+ db := filledStateDB()
+ prefetcher := newTriePrefetcher(db.db, db.originalRoot, "")
+ skey := common.HexToHash("aaa")
+ prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()})
+ cpy := prefetcher.copy()
+ a := prefetcher.trie(common.Hash{}, db.originalRoot)
+ b := cpy.trie(common.Hash{}, db.originalRoot)
+ prefetcher.close()
+ c := prefetcher.trie(common.Hash{}, db.originalRoot)
+ d := cpy.trie(common.Hash{}, db.originalRoot)
+ if a == nil {
+ t.Fatal("Prefetching before close should not return nil")
+ }
+ if b == nil {
+ t.Fatal("Copy trie should return nil")
+ }
+ if c != nil {
+ t.Fatal("Trie after close should return nil")
+ }
+ if d == nil {
+ t.Fatal("Copy trie should not return nil")
+ }
+}
diff --git a/trie_by_cid/trie/committer.go b/trie_by_cid/trie/committer.go
new file mode 100644
index 0000000..9f97887
--- /dev/null
+++ b/trie_by_cid/trie/committer.go
@@ -0,0 +1,196 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package trie
+
+import (
+ "fmt"
+
+ "github.com/ethereum/go-ethereum/common"
+)
+
+// leaf represents a trie leaf node
+type leaf struct {
+ blob []byte // raw blob of leaf
+ parent common.Hash // the hash of parent node
+}
+
+// committer is the tool used for the trie Commit operation. The committer will
+// capture all dirty nodes during the commit process and keep them cached in
+// insertion order.
+type committer struct {
+ nodes *NodeSet
+ collectLeaf bool
+}
+
+// newCommitter creates a new committer or picks one from the pool.
+func newCommitter(nodeset *NodeSet, collectLeaf bool) *committer {
+ return &committer{
+ nodes: nodeset,
+ collectLeaf: collectLeaf,
+ }
+}
+
+// Commit collapses a node down into a hash node.
+func (c *committer) Commit(n node) hashNode {
+ return c.commit(nil, n).(hashNode)
+}
+
+// commit collapses a node down into a hash node and returns it.
+func (c *committer) commit(path []byte, n node) node {
+ // if this path is clean, use available cached data
+ hash, dirty := n.cache()
+ if hash != nil && !dirty {
+ return hash
+ }
+ // Commit children, then parent, and remove the dirty flag.
+ switch cn := n.(type) {
+ case *shortNode:
+ // Commit child
+ collapsed := cn.copy()
+
+ // If the child is fullNode, recursively commit,
+ // otherwise it can only be hashNode or valueNode.
+ if _, ok := cn.Val.(*fullNode); ok {
+ collapsed.Val = c.commit(append(path, cn.Key...), cn.Val)
+ }
+ // The key needs to be copied, since we're adding it to the
+ // modified nodeset.
+ collapsed.Key = hexToCompact(cn.Key)
+ hashedNode := c.store(path, collapsed)
+ if hn, ok := hashedNode.(hashNode); ok {
+ return hn
+ }
+ return collapsed
+ case *fullNode:
+ hashedKids := c.commitChildren(path, cn)
+ collapsed := cn.copy()
+ collapsed.Children = hashedKids
+
+ hashedNode := c.store(path, collapsed)
+ if hn, ok := hashedNode.(hashNode); ok {
+ return hn
+ }
+ return collapsed
+ case hashNode:
+ return cn
+ default:
+ // nil, valuenode shouldn't be committed
+ panic(fmt.Sprintf("%T: invalid node: %v", n, n))
+ }
+}
+
+// commitChildren commits the children of the given fullnode
+func (c *committer) commitChildren(path []byte, n *fullNode) [17]node {
+ var children [17]node
+ for i := 0; i < 16; i++ {
+ child := n.Children[i]
+ if child == nil {
+ continue
+ }
+ // If it's the hashed child, save the hash value directly.
+ // Note: it's impossible that the child in range [0, 15]
+ // is a valueNode.
+ if hn, ok := child.(hashNode); ok {
+ children[i] = hn
+ continue
+ }
+ // Commit the child recursively and store the "hashed" value.
+ // Note the returned node can be some embedded nodes, so it's
+ // possible the type is not hashNode.
+ children[i] = c.commit(append(path, byte(i)), child)
+ }
+ // For the 17th child, it's possible the type is valuenode.
+ if n.Children[16] != nil {
+ children[16] = n.Children[16]
+ }
+ return children
+}
+
+// store hashes the node n and adds it to the modified nodeset. If leaf collection
+// is enabled, leaf nodes will be tracked in the modified nodeset as well.
+func (c *committer) store(path []byte, n node) node {
+ // Larger nodes are replaced by their hash and stored in the database.
+ var hash, _ = n.cache()
+
+ // This was not generated - it must be a small node stored in the parent.
+ // In theory, we should check here whether the node is a leaf (an embedded
+ // node usually is). But small values (less than 32 bytes) are not our
+ // target (we only collect leaves of the account trie).
+ if hash == nil {
+ // The node is embedded in its parent, in other words, this node
+ // will not be stored in the database independently, mark it as
+ // deleted only if the node was existent in database before.
+ if _, ok := c.nodes.accessList[string(path)]; ok {
+ c.nodes.markDeleted(path)
+ }
+ return n
+ }
+ // We already have the hash, so estimate the RLP encoding size of the node.
+ // The size is used for memory tracking and does not need to be exact.
+ var (
+ size = estimateSize(n)
+ nhash = common.BytesToHash(hash)
+ mnode = &memoryNode{
+ hash: nhash,
+ node: simplifyNode(n),
+ size: uint16(size),
+ }
+ )
+ // Collect the dirty node to nodeset for return.
+ c.nodes.markUpdated(path, mnode)
+
+ // Collect the corresponding leaf node if it's required. We don't check
+ // full nodes since it's impossible to store a value in a fullNode. The
+ // key length of leaves should be exactly the same.
+ if c.collectLeaf {
+ if sn, ok := n.(*shortNode); ok {
+ if val, ok := sn.Val.(valueNode); ok {
+ c.nodes.addLeaf(&leaf{blob: val, parent: nhash})
+ }
+ }
+ }
+ return hash
+}
+
+// estimateSize estimates the size of an rlp-encoded node, without actually
+// rlp-encoding it (zero allocs). This method has been tested experimentally:
+// with a trie of 1000 leaves, the only errors above 1% are on small shortnodes,
+// where this method overestimates by 2 or 3 bytes (e.g. 37 instead of 35).
+func estimateSize(n node) int {
+ switch n := n.(type) {
+ case *shortNode:
+ // A short node contains a compacted key, and a value.
+ return 3 + len(n.Key) + estimateSize(n.Val)
+ case *fullNode:
+ // A full node contains up to 16 hashes (some nils), plus encoding overhead
+ s := 3
+ for i := 0; i < 16; i++ {
+ if child := n.Children[i]; child != nil {
+ s += estimateSize(child)
+ } else {
+ s++
+ }
+ }
+ return s
+ case valueNode:
+ return 1 + len(n)
+ case hashNode:
+ return 1 + len(n)
+ default:
+ panic(fmt.Sprintf("node type %T", n))
+ }
+}
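
For intuition, here is a small worked example of estimateSize on a hand-built internal node (illustrative only; estimateSizeSketch and the node literal are assumptions for the sketch):

func estimateSizeSketch() int {
	n := &shortNode{
		Key: []byte{1, 2, 3, 4, 5},            // 5 key bytes
		Val: valueNode([]byte{0xa, 0xb, 0xc}), // 1 + 3
	}
	return estimateSize(n) // 3 + 5 + (1 + 3) = 12
}
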
diff --git a/trie_by_cid/trie/database.go b/trie_by_cid/trie/database.go
index d13cb49..ac05f98 100644
--- a/trie_by_cid/trie/database.go
+++ b/trie_by_cid/trie/database.go
@@ -18,25 +18,50 @@ package trie
import (
"errors"
+ "runtime"
+ "sync"
+ "time"
"github.com/VictoriaMetrics/fastcache"
+ "github.com/cerc-io/ipld-eth-statedb/internal"
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethdb"
+ "github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie"
- "github.com/ipfs/go-cid"
+ log "github.com/sirupsen/logrus"
)
-type CidKey = cid.Cid
-
-func isEmpty(key CidKey) bool {
- return len(key.KeyString()) == 0
-}
-
-// Database is an intermediate read-only layer between the trie data structures and
-// the disk database. This trie Database is thread safe in providing individual,
-// independent node access.
+// Database is an intermediate write layer between the trie data structures and
+// the disk database. The aim is to accumulate trie writes in-memory and only
+// periodically flush a couple tries to disk, garbage collecting the remainder.
+//
+// Note, the trie Database is **not** thread safe in its mutations, but it **is**
+// thread safe in providing individual, independent node access. The rationale
+// behind this split design is to provide read access to RPC handlers and sync
+// servers even while the trie is executing expensive garbage collection.
type Database struct {
- diskdb ethdb.KeyValueStore // Persistent storage for matured trie nodes
- cleans *fastcache.Cache // GC friendly memory cache of clean node RLPs
+ diskdb ethdb.Database // Persistent storage for matured trie nodes
+
+ cleans *fastcache.Cache // GC friendly memory cache of clean node RLPs
+ dirties map[common.Hash]*cachedNode // Data and references relationships of dirty trie nodes
+ oldest common.Hash // Oldest tracked node, flush-list head
+ newest common.Hash // Newest tracked node, flush-list tail
+
+ gctime time.Duration // Time spent on garbage collection since last commit
+ gcnodes uint64 // Nodes garbage collected since last commit
+ gcsize common.StorageSize // Data storage garbage collected since last commit
+
+ flushtime time.Duration // Time spent on data flushing since last commit
+ flushnodes uint64 // Nodes flushed since last commit
+ flushsize common.StorageSize // Data storage flushed since last commit
+
+ dirtiesSize common.StorageSize // Storage size of the dirty node cache (exc. metadata)
+ childrenSize common.StorageSize // Storage size of the external children tracking
+ preimages *preimageStore // The store for caching preimages
+
+ lock sync.RWMutex
}
// Config defines all necessary options for database.
@@ -46,14 +71,14 @@ type Config = trie.Config
// NewDatabase creates a new trie database to store ephemeral trie content before
// it's written out to disk or garbage collected. No read cache is created, so all
// data retrievals will hit the underlying disk database.
-func NewDatabase(diskdb ethdb.KeyValueStore) *Database {
+func NewDatabase(diskdb ethdb.Database) *Database {
return NewDatabaseWithConfig(diskdb, nil)
}
// NewDatabaseWithConfig creates a new trie database to store ephemeral trie content
-// before it's written out to disk or garbage collected. It also acts as a read cache
+// before its written out to disk or garbage collected. It also acts as a read cache
// for nodes loaded from disk.
-func NewDatabaseWithConfig(diskdb ethdb.KeyValueStore, config *Config) *Database {
+func NewDatabaseWithConfig(diskdb ethdb.Database, config *Config) *Database {
var cleans *fastcache.Cache
if config != nil && config.Cache > 0 {
if config.Journal == "" {
@@ -62,44 +87,354 @@ func NewDatabaseWithConfig(diskdb ethdb.KeyValueStore, config *Config) *Database
cleans = fastcache.LoadFromFileOrNew(config.Journal, config.Cache*1024*1024)
}
}
+ var preimage *preimageStore
+ if config != nil && config.Preimages {
+ preimage = newPreimageStore(diskdb)
+ }
db := &Database{
diskdb: diskdb,
cleans: cleans,
+ dirties: map[common.Hash]*cachedNode{{}: {
+ children: make(map[common.Hash]uint16),
+ }},
+ preimages: preimage,
}
return db
}
-// DiskDB retrieves the persistent storage backing the trie database.
-func (db *Database) DiskDB() ethdb.KeyValueStore {
- return db.diskdb
+// insert inserts a simplified trie node into the memory database.
+// All nodes inserted by this function will be reference tracked
+// and in theory should only be used for **trie node** insertion.
+func (db *Database) insert(hash common.Hash, size int, node node) {
+ // If the node's already cached, skip
+ if _, ok := db.dirties[hash]; ok {
+ return
+ }
+ memcacheDirtyWriteMeter.Mark(int64(size))
+
+ // Create the cached entry for this node
+ entry := &cachedNode{
+ node: node,
+ size: uint16(size),
+ flushPrev: db.newest,
+ }
+ entry.forChilds(func(child common.Hash) {
+ if c := db.dirties[child]; c != nil {
+ c.parents++
+ }
+ })
+ db.dirties[hash] = entry
+
+ // Update the flush-list endpoints
+ if db.oldest == (common.Hash{}) {
+ db.oldest, db.newest = hash, hash
+ } else {
+ db.dirties[db.newest].flushNext, db.newest = hash, hash
+ }
+ db.dirtiesSize += common.StorageSize(common.HashLength + entry.size)
}
-// Node retrieves an encoded trie node by CID. If it cannot be found
-// cached in memory, it queries the persistent database.
-func (db *Database) Node(key CidKey) ([]byte, error) {
+// Node retrieves an encoded cached trie node from memory. If it cannot be found
+// cached, the method queries the persistent database for the content.
+func (db *Database) Node(hash common.Hash, codec uint64) ([]byte, error) {
// It doesn't make sense to retrieve the metaroot
- if isEmpty(key) {
+ if hash == (common.Hash{}) {
return nil, errors.New("not found")
}
- cidbytes := key.Bytes()
// Retrieve the node from the clean cache if available
if db.cleans != nil {
- if enc := db.cleans.Get(nil, cidbytes); enc != nil {
+ if enc := db.cleans.Get(nil, hash[:]); enc != nil {
+ memcacheCleanHitMeter.Mark(1)
+ memcacheCleanReadMeter.Mark(int64(len(enc)))
return enc, nil
}
}
+ // Retrieve the node from the dirty cache if available
+ db.lock.RLock()
+ dirty := db.dirties[hash]
+ db.lock.RUnlock()
+
+ if dirty != nil {
+ memcacheDirtyHitMeter.Mark(1)
+ memcacheDirtyReadMeter.Mark(int64(dirty.size))
+ return dirty.rlp(), nil
+ }
+ memcacheDirtyMissMeter.Mark(1)
// Content unavailable in memory, attempt to retrieve from disk
- enc, err := db.diskdb.Get(cidbytes)
+ cid, err := internal.Keccak256ToCid(codec, hash[:])
+ if err != nil {
+ return nil, err
+ }
+ enc, err := db.diskdb.Get(cid.Bytes())
if err != nil {
return nil, err
}
-
if len(enc) != 0 {
if db.cleans != nil {
- db.cleans.Set(cidbytes, enc)
+ db.cleans.Set(hash[:], enc)
+ memcacheCleanMissMeter.Mark(1)
+ memcacheCleanWriteMeter.Mark(int64(len(enc)))
}
return enc, nil
}
return nil, errors.New("not found")
}
+
+// Nodes retrieves the hashes of all the nodes cached within the memory database.
+// This method is extremely expensive and should only be used to validate internal
+// states in test code.
+func (db *Database) Nodes() []common.Hash {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ var hashes = make([]common.Hash, 0, len(db.dirties))
+ for hash := range db.dirties {
+ if hash != (common.Hash{}) { // Special case for "root" references/nodes
+ hashes = append(hashes, hash)
+ }
+ }
+ return hashes
+}
+
+// Reference adds a new reference from a parent node to a child node.
+// This function is used to add a reference between an internal trie node
+// and an external node (e.g. a storage trie root); all internal trie nodes
+// are referenced together by the database itself.
+func (db *Database) Reference(child common.Hash, parent common.Hash) {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ db.reference(child, parent)
+}
+
+// reference is the private locked version of Reference.
+func (db *Database) reference(child common.Hash, parent common.Hash) {
+ // If the node does not exist, it's a node pulled from disk, skip
+ node, ok := db.dirties[child]
+ if !ok {
+ return
+ }
+ // If the reference already exists, only duplicate for roots
+ if db.dirties[parent].children == nil {
+ db.dirties[parent].children = make(map[common.Hash]uint16)
+ db.childrenSize += cachedNodeChildrenSize
+ } else if _, ok = db.dirties[parent].children[child]; ok && parent != (common.Hash{}) {
+ return
+ }
+ node.parents++
+ db.dirties[parent].children[child]++
+ if db.dirties[parent].children[child] == 1 {
+ db.childrenSize += common.HashLength + 2 // uint16 counter
+ }
+}
+
+// Dereference removes an existing reference from a root node.
+func (db *Database) Dereference(root common.Hash) {
+ // Sanity check to ensure that the meta-root is not removed
+ if root == (common.Hash{}) {
+ log.Error("Attempted to dereference the trie cache meta root")
+ return
+ }
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ nodes, storage, start := len(db.dirties), db.dirtiesSize, time.Now()
+ db.dereference(root, common.Hash{})
+
+ db.gcnodes += uint64(nodes - len(db.dirties))
+ db.gcsize += storage - db.dirtiesSize
+ db.gctime += time.Since(start)
+
+ memcacheGCTimeTimer.Update(time.Since(start))
+ memcacheGCSizeMeter.Mark(int64(storage - db.dirtiesSize))
+ memcacheGCNodesMeter.Mark(int64(nodes - len(db.dirties)))
+
+ log.Debug("Dereferenced trie from memory database", "nodes", nodes-len(db.dirties), "size", storage-db.dirtiesSize, "time", time.Since(start),
+ "gcnodes", db.gcnodes, "gcsize", db.gcsize, "gctime", db.gctime, "livenodes", len(db.dirties), "livesize", db.dirtiesSize)
+}
+
+// dereference is the private locked version of Dereference.
+func (db *Database) dereference(child common.Hash, parent common.Hash) {
+ // Dereference the parent-child
+ node := db.dirties[parent]
+
+ if node.children != nil && node.children[child] > 0 {
+ node.children[child]--
+ if node.children[child] == 0 {
+ delete(node.children, child)
+ db.childrenSize -= (common.HashLength + 2) // uint16 counter
+ }
+ }
+ // If the child does not exist, it's a previously committed node.
+ node, ok := db.dirties[child]
+ if !ok {
+ return
+ }
+ // If there are no more references to the child, delete it and cascade
+ if node.parents > 0 {
+ // This is a special cornercase where a node loaded from disk (i.e. not in the
+ // memcache any more) gets reinjected as a new node (short node split into full,
+ // then reverted into short), causing a cached node to have no parents. That is
+ // no problem in itself, but don't make maxint parents out of it.
+ node.parents--
+ }
+ if node.parents == 0 {
+ // Remove the node from the flush-list
+ switch child {
+ case db.oldest:
+ db.oldest = node.flushNext
+ db.dirties[node.flushNext].flushPrev = common.Hash{}
+ case db.newest:
+ db.newest = node.flushPrev
+ db.dirties[node.flushPrev].flushNext = common.Hash{}
+ default:
+ db.dirties[node.flushPrev].flushNext = node.flushNext
+ db.dirties[node.flushNext].flushPrev = node.flushPrev
+ }
+ // Dereference all children and delete the node
+ node.forChilds(func(hash common.Hash) {
+ db.dereference(hash, child)
+ })
+ delete(db.dirties, child)
+ db.dirtiesSize -= common.StorageSize(common.HashLength + int(node.size))
+ if node.children != nil {
+ db.childrenSize -= cachedNodeChildrenSize
+ }
+ }
+}
+
+// Update inserts the dirty nodes in the provided nodeset into the database
+// and links the account trie with multiple storage tries if necessary.
+func (db *Database) Update(nodes *MergedNodeSet) error {
+ db.lock.Lock()
+ defer db.lock.Unlock()
+
+ // Insert dirty nodes into the database. In the same tree, it must be
+ // ensured that children are inserted first, then parent so that children
+ // can be linked with their parent correctly.
+ //
+ // Note, the storage tries must be flushed before the account trie to
+ // retain the invariant that children go into the dirty cache first.
+ var order []common.Hash
+ for owner := range nodes.sets {
+ if owner == (common.Hash{}) {
+ continue
+ }
+ order = append(order, owner)
+ }
+ if _, ok := nodes.sets[common.Hash{}]; ok {
+ order = append(order, common.Hash{})
+ }
+ for _, owner := range order {
+ subset := nodes.sets[owner]
+ subset.forEachWithOrder(func(path string, n *memoryNode) {
+ if n.isDeleted() {
+ return // ignore deletion
+ }
+ db.insert(n.hash, int(n.size), n.node)
+ })
+ }
+ // Link up the account trie and storage trie if the node points
+ // to an account trie leaf.
+ if set, present := nodes.sets[common.Hash{}]; present {
+ for _, n := range set.leaves {
+ var account types.StateAccount
+ if err := rlp.DecodeBytes(n.blob, &account); err != nil {
+ return err
+ }
+ if account.Root != types.EmptyRootHash {
+ db.reference(account.Root, n.parent)
+ }
+ }
+ }
+ return nil
+}
+
+// Size returns the current storage size of the memory cache in front of the
+// persistent database layer.
+func (db *Database) Size() (common.StorageSize, common.StorageSize) {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ // db.dirtiesSize only contains the useful data in the cache, but when reporting
+ // the total memory consumption, the maintenance metadata also needs to be
+ // counted.
+ var metadataSize = common.StorageSize((len(db.dirties) - 1) * cachedNodeSize)
+ var metarootRefs = common.StorageSize(len(db.dirties[common.Hash{}].children) * (common.HashLength + 2))
+ var preimageSize common.StorageSize
+ if db.preimages != nil {
+ preimageSize = db.preimages.size()
+ }
+ return db.dirtiesSize + db.childrenSize + metadataSize - metarootRefs, preimageSize
+}
+
+// GetReader retrieves a node reader belonging to the given state root.
+func (db *Database) GetReader(root common.Hash, codec uint64) Reader {
+ return &hashReader{db: db, codec: codec}
+}
+
+// hashReader is a reader of the hash database which implements the Reader interface.
+type hashReader struct {
+ db *Database
+ codec uint64
+}
+
+// Node retrieves the trie node with the given node hash.
+func (reader *hashReader) Node(owner common.Hash, path []byte, hash common.Hash) (node, error) {
+ blob, err := reader.NodeBlob(owner, path, hash)
+ if err != nil {
+ return nil, err
+ }
+ return decodeNodeUnsafe(hash[:], blob)
+}
+
+// NodeBlob retrieves the RLP-encoded trie node blob with the given node hash.
+func (reader *hashReader) NodeBlob(_ common.Hash, _ []byte, hash common.Hash) ([]byte, error) {
+ return reader.db.Node(hash, reader.codec)
+}
+
+// saveCache saves the clean state cache to the given directory path
+// using the specified number of CPU cores.
+func (db *Database) saveCache(dir string, threads int) error {
+ if db.cleans == nil {
+ return nil
+ }
+ log.Info("Writing clean trie cache to disk", "path", dir, "threads", threads)
+
+ start := time.Now()
+ err := db.cleans.SaveToFileConcurrent(dir, threads)
+ if err != nil {
+ log.Error("Failed to persist clean trie cache", "error", err)
+ return err
+ }
+ log.Info("Persisted the clean trie cache", "path", dir, "elapsed", common.PrettyDuration(time.Since(start)))
+ return nil
+}
+
+// SaveCache atomically saves fast cache data to the given dir using all
+// available CPU cores.
+func (db *Database) SaveCache(dir string) error {
+ return db.saveCache(dir, runtime.GOMAXPROCS(0))
+}
+
+// SaveCachePeriodically atomically saves fast cache data to the given dir at
+// the specified interval. Each dump operation will use only a single CPU core.
+func (db *Database) SaveCachePeriodically(dir string, interval time.Duration, stopCh <-chan struct{}) {
+ ticker := time.NewTicker(interval)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ticker.C:
+ db.saveCache(dir, 1)
+ case <-stopCh:
+ return
+ }
+ }
+}
+
+// Scheme returns the node scheme used in the database.
+func (db *Database) Scheme() string {
+ return rawdb.HashScheme
+}
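
Putting the pieces of this rewritten Database together, here is a hedged sketch of the write/read flow (commitExample and its arguments are hypothetical; StateTrieCodec is the codec constant this package's tests use below):

func commitExample(db *Database, nodes *MergedNodeSet, root common.Hash) error {
	// Update inserts dirty nodes (storage tries before the account trie) and
	// links account leaves to their storage roots via reference counts.
	if err := db.Update(nodes); err != nil {
		return err
	}
	// Callers keep roots alive explicitly; dereferencing a root later
	// garbage-collects every node whose reference count drops to zero.
	db.Reference(root, common.Hash{})
	defer db.Dereference(root)

	// Reads check the clean and dirty caches (keyed by keccak256 hash) and
	// only fall back to a CID-keyed disk lookup via internal.Keccak256ToCid.
	_, err := db.Node(root, StateTrieCodec)
	return err
}
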
diff --git a/trie_by_cid/trie/database_metrics.go b/trie_by_cid/trie/database_metrics.go
new file mode 100644
index 0000000..55efc55
--- /dev/null
+++ b/trie_by_cid/trie/database_metrics.go
@@ -0,0 +1,27 @@
+package trie
+
+import "github.com/ethereum/go-ethereum/metrics"
+
+var (
+ memcacheCleanHitMeter = metrics.NewRegisteredMeter("trie/memcache/clean/hit", nil)
+ memcacheCleanMissMeter = metrics.NewRegisteredMeter("trie/memcache/clean/miss", nil)
+ memcacheCleanReadMeter = metrics.NewRegisteredMeter("trie/memcache/clean/read", nil)
+ memcacheCleanWriteMeter = metrics.NewRegisteredMeter("trie/memcache/clean/write", nil)
+
+ memcacheDirtyHitMeter = metrics.NewRegisteredMeter("trie/memcache/dirty/hit", nil)
+ memcacheDirtyMissMeter = metrics.NewRegisteredMeter("trie/memcache/dirty/miss", nil)
+ memcacheDirtyReadMeter = metrics.NewRegisteredMeter("trie/memcache/dirty/read", nil)
+ memcacheDirtyWriteMeter = metrics.NewRegisteredMeter("trie/memcache/dirty/write", nil)
+
+ memcacheFlushTimeTimer = metrics.NewRegisteredResettingTimer("trie/memcache/flush/time", nil)
+ memcacheFlushNodesMeter = metrics.NewRegisteredMeter("trie/memcache/flush/nodes", nil)
+ memcacheFlushSizeMeter = metrics.NewRegisteredMeter("trie/memcache/flush/size", nil)
+
+ memcacheGCTimeTimer = metrics.NewRegisteredResettingTimer("trie/memcache/gc/time", nil)
+ memcacheGCNodesMeter = metrics.NewRegisteredMeter("trie/memcache/gc/nodes", nil)
+ memcacheGCSizeMeter = metrics.NewRegisteredMeter("trie/memcache/gc/size", nil)
+
+ memcacheCommitTimeTimer = metrics.NewRegisteredResettingTimer("trie/memcache/commit/time", nil)
+ memcacheCommitNodesMeter = metrics.NewRegisteredMeter("trie/memcache/commit/nodes", nil)
+ memcacheCommitSizeMeter = metrics.NewRegisteredMeter("trie/memcache/commit/size", nil)
+)
diff --git a/trie_by_cid/trie/database_node.go b/trie_by_cid/trie/database_node.go
new file mode 100644
index 0000000..58037ae
--- /dev/null
+++ b/trie_by_cid/trie/database_node.go
@@ -0,0 +1,183 @@
+package trie
+
+import (
+ "fmt"
+ "io"
+ "reflect"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/rlp"
+)
+
+// rawNode is a simple binary blob used to differentiate between collapsed trie
+// nodes and already encoded RLP binary blobs (while at the same time store them
+// in the same cache fields).
+type rawNode []byte
+
+func (n rawNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") }
+func (n rawNode) fstring(ind string) string { panic("this should never end up in a live trie") }
+
+func (n rawNode) EncodeRLP(w io.Writer) error {
+ _, err := w.Write(n)
+ return err
+}
+
+// rawFullNode represents only the useful data content of a full node, with the
+// caches and flags stripped out to minimize its data storage. This type honors
+// the same RLP encoding as the original parent.
+type rawFullNode [17]node
+
+func (n rawFullNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") }
+func (n rawFullNode) fstring(ind string) string { panic("this should never end up in a live trie") }
+
+func (n rawFullNode) EncodeRLP(w io.Writer) error {
+ eb := rlp.NewEncoderBuffer(w)
+ n.encode(eb)
+ return eb.Flush()
+}
+
+// rawShortNode represents only the useful data content of a short node, with the
+// caches and flags stripped out to minimize its data storage. This type honors
+// the same RLP encoding as the original parent.
+type rawShortNode struct {
+ Key []byte
+ Val node
+}
+
+func (n rawShortNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") }
+func (n rawShortNode) fstring(ind string) string { panic("this should never end up in a live trie") }
+
+// cachedNode is all the information we know about a single cached trie node
+// in the memory database write layer.
+type cachedNode struct {
+ node node // Cached collapsed trie node, or raw rlp data
+ size uint16 // Byte size of the useful cached data
+
+ parents uint32 // Number of live nodes referencing this one
+ children map[common.Hash]uint16 // External children referenced by this node
+
+ flushPrev common.Hash // Previous node in the flush-list
+ flushNext common.Hash // Next node in the flush-list
+}
+
+// cachedNodeSize is the raw size of a cachedNode data structure without any
+// node data included. It's an approximate size, but should be a lot better
+// than not counting them.
+var cachedNodeSize = int(reflect.TypeOf(cachedNode{}).Size())
+
+// cachedNodeChildrenSize is the raw size of an initialized but empty external
+// reference map.
+const cachedNodeChildrenSize = 48
+
+// rlp returns the raw rlp encoded blob of the cached trie node, either directly
+// from the cache, or by regenerating it from the collapsed node.
+func (n *cachedNode) rlp() []byte {
+ if node, ok := n.node.(rawNode); ok {
+ return node
+ }
+ return nodeToBytes(n.node)
+}
+
+// obj returns the decoded and expanded trie node, either directly from the cache,
+// or by regenerating it from the rlp encoded blob.
+func (n *cachedNode) obj(hash common.Hash) node {
+ if node, ok := n.node.(rawNode); ok {
+ // The raw-blob format nodes are loaded either from the
+ // clean cache or the database, they are all in their own
+ // copy and safe to use unsafe decoder.
+ return mustDecodeNodeUnsafe(hash[:], node)
+ }
+ return expandNode(hash[:], n.node)
+}
+
+// forChilds invokes the callback for all the tracked children of this node,
+// both the implicit ones from inside the node as well as the explicit ones
+// from outside the node.
+func (n *cachedNode) forChilds(onChild func(hash common.Hash)) {
+ for child := range n.children {
+ onChild(child)
+ }
+ if _, ok := n.node.(rawNode); !ok {
+ forGatherChildren(n.node, onChild)
+ }
+}
+
+// forGatherChildren traverses the node hierarchy of a collapsed storage node and
+// invokes the callback for all the hashnode children.
+func forGatherChildren(n node, onChild func(hash common.Hash)) {
+ switch n := n.(type) {
+ case *rawShortNode:
+ forGatherChildren(n.Val, onChild)
+ case rawFullNode:
+ for i := 0; i < 16; i++ {
+ forGatherChildren(n[i], onChild)
+ }
+ case hashNode:
+ onChild(common.BytesToHash(n))
+ case valueNode, nil, rawNode:
+ default:
+ panic(fmt.Sprintf("unknown node type: %T", n))
+ }
+}
+
+// simplifyNode traverses the hierarchy of an expanded memory node and discards
+// all the internal caches, returning a node that only contains the raw data.
+func simplifyNode(n node) node {
+ switch n := n.(type) {
+ case *shortNode:
+ // Short nodes discard the flags and cascade
+ return &rawShortNode{Key: n.Key, Val: simplifyNode(n.Val)}
+
+ case *fullNode:
+ // Full nodes discard the flags and cascade
+ node := rawFullNode(n.Children)
+ for i := 0; i < len(node); i++ {
+ if node[i] != nil {
+ node[i] = simplifyNode(node[i])
+ }
+ }
+ return node
+
+ case valueNode, hashNode, rawNode:
+ return n
+
+ default:
+ panic(fmt.Sprintf("unknown node type: %T", n))
+ }
+}
+
+// expandNode traverses the node hierarchy of a collapsed storage node and converts
+// all fields and keys into expanded memory form.
+func expandNode(hash hashNode, n node) node {
+ switch n := n.(type) {
+ case *rawShortNode:
+ // Short nodes need key and child expansion
+ return &shortNode{
+ Key: compactToHex(n.Key),
+ Val: expandNode(nil, n.Val),
+ flags: nodeFlag{
+ hash: hash,
+ },
+ }
+
+ case rawFullNode:
+ // Full nodes need child expansion
+ node := &fullNode{
+ flags: nodeFlag{
+ hash: hash,
+ },
+ }
+ for i := 0; i < len(node.Children); i++ {
+ if n[i] != nil {
+ node.Children[i] = expandNode(nil, n[i])
+ }
+ }
+ return node
+
+ case valueNode, hashNode:
+ return n
+
+ default:
+ panic(fmt.Sprintf("unknown node type: %T", n))
+ }
+}
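
A hedged sketch of the collapse/expand round trip these helpers implement (roundTripSketch is hypothetical; the committer is what compacts the key before storage, which is why expandNode converts it back):

func roundTripSketch() node {
	orig := &shortNode{
		Key: hexToCompact([]byte{1, 2, 3, 16}), // compact form, as stored
		Val: valueNode([]byte("v")),
	}
	stripped := simplifyNode(orig)   // *rawShortNode: raw data, caches dropped
	return expandNode(nil, stripped) // *shortNode with hex key [1 2 3 16] again
}
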
diff --git a/trie_by_cid/trie/database_test.go b/trie_by_cid/trie/database_test.go
index a800135..5a6e36d 100644
--- a/trie_by_cid/trie/database_test.go
+++ b/trie_by_cid/trie/database_test.go
@@ -14,21 +14,20 @@
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-package trie_test
+package trie
import (
"testing"
- "github.com/ethereum/go-ethereum/ethdb/memorydb"
-
- "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie"
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
)
// Tests that the trie database returns a missing trie node error if attempting
// to retrieve the meta root.
func TestDatabaseMetarootFetch(t *testing.T) {
- db := trie.NewDatabase(memorydb.New())
- if _, err := db.Node(trie.CidKey{}); err == nil {
+ db := NewDatabase(rawdb.NewMemoryDatabase())
+ if _, err := db.Node(common.Hash{}, StateTrieCodec); err == nil {
t.Fatalf("metaroot retrieval succeeded")
}
}
diff --git a/trie_by_cid/trie/encoding.go b/trie_by_cid/trie/encoding.go
index 381883a..ace4570 100644
--- a/trie_by_cid/trie/encoding.go
+++ b/trie_by_cid/trie/encoding.go
@@ -16,8 +16,81 @@
package trie
+// Trie keys are dealt with in three distinct encodings:
+//
+// KEYBYTES encoding contains the actual key and nothing else. This encoding is the
+// input to most API functions.
+//
+// HEX encoding contains one byte for each nibble of the key and an optional trailing
+// 'terminator' byte of value 0x10 which indicates whether or not the node at the key
+// contains a value. Hex key encoding is used for nodes loaded in memory because it's
+// convenient to access.
+//
+// COMPACT encoding is defined by the Ethereum Yellow Paper (it's called "hex prefix
+// encoding" there) and contains the bytes of the key and a flag. The high nibble of the
+// first byte contains the flag; the lowest bit encoding the oddness of the length and
+// the second-lowest encoding whether the node at the key is a value node. The low nibble
+// of the first byte is zero in the case of an even number of nibbles and the first nibble
+// in the case of an odd number. All remaining nibbles (now an even number) fit properly
+// into the remaining bytes. Compact encoding is used for nodes stored on disk.
+
+// HexToCompact converts a hex path to the compact encoded format
+func HexToCompact(hex []byte) []byte {
+ return hexToCompact(hex)
+}
+
+func hexToCompact(hex []byte) []byte {
+ terminator := byte(0)
+ if hasTerm(hex) {
+ terminator = 1
+ hex = hex[:len(hex)-1]
+ }
+ buf := make([]byte, len(hex)/2+1)
+ buf[0] = terminator << 5 // the flag byte
+ if len(hex)&1 == 1 {
+ buf[0] |= 1 << 4 // odd flag
+ buf[0] |= hex[0] // first nibble is contained in the first byte
+ hex = hex[1:]
+ }
+ decodeNibbles(hex, buf[1:])
+ return buf
+}
+
+// hexToCompactInPlace places the compact key in the input buffer, returning the
+// length needed for the representation.
+func hexToCompactInPlace(hex []byte) int {
+ var (
+ hexLen = len(hex) // length of the hex input
+ firstByte = byte(0)
+ )
+ // Check if we have a terminator there
+ if hexLen > 0 && hex[hexLen-1] == 16 {
+ firstByte = 1 << 5
+ hexLen-- // last part was the terminator, ignore that
+ }
+ var (
+ binLen = hexLen/2 + 1
+ ni = 0 // index in hex
+ bi = 1 // index in bin (compact)
+ )
+ if hexLen&1 == 1 {
+ firstByte |= 1 << 4 // odd flag
+ firstByte |= hex[0] // first nibble is contained in the first byte
+ ni++
+ }
+ for ; ni < hexLen; bi, ni = bi+1, ni+2 {
+ hex[bi] = hex[ni]<<4 | hex[ni+1]
+ }
+ hex[0] = firstByte
+ return binLen
+}
+
// CompactToHex converts a compact encoded path to hex format
func CompactToHex(compact []byte) []byte {
+ return compactToHex(compact)
+}
+
+func compactToHex(compact []byte) []byte {
if len(compact) == 0 {
return compact
}
@@ -62,6 +135,20 @@ func decodeNibbles(nibbles []byte, bytes []byte) {
}
}
+// prefixLen returns the length of the common prefix of a and b.
+func prefixLen(a, b []byte) int {
+ var i, length = 0, len(a)
+ if len(b) < length {
+ length = len(b)
+ }
+ for ; i < length; i++ {
+ if a[i] != b[i] {
+ break
+ }
+ }
+ return i
+}
+
// hasTerm returns whether a hex key has the terminator flag.
func hasTerm(s []byte) bool {
return len(s) > 0 && s[len(s)-1] == 16
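
A concrete pass through the three encodings for a single key may help (encodingSketch is a hypothetical helper; the test file below covers many more cases):

func encodingSketch() {
	key := []byte{0x12, 0x34}       // KEYBYTES
	hexKey := keybytesToHex(key)    // HEX: [1 2 3 4 16] -- terminator nibble appended
	compact := hexToCompact(hexKey) // COMPACT: [0x20 0x12 0x34] (flag byte: terminator set, even length)
	_ = compactToHex(compact)       // back to HEX: [1 2 3 4 16]
}
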
diff --git a/trie_by_cid/trie/encoding_test.go b/trie_by_cid/trie/encoding_test.go
new file mode 100644
index 0000000..abc1e9d
--- /dev/null
+++ b/trie_by_cid/trie/encoding_test.go
@@ -0,0 +1,141 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package trie
+
+import (
+ "bytes"
+ crand "crypto/rand"
+ "encoding/hex"
+ "math/rand"
+ "testing"
+)
+
+func TestHexCompact(t *testing.T) {
+ tests := []struct{ hex, compact []byte }{
+ // empty keys, with and without terminator.
+ {hex: []byte{}, compact: []byte{0x00}},
+ {hex: []byte{16}, compact: []byte{0x20}},
+ // odd length, no terminator
+ {hex: []byte{1, 2, 3, 4, 5}, compact: []byte{0x11, 0x23, 0x45}},
+ // even length, no terminator
+ {hex: []byte{0, 1, 2, 3, 4, 5}, compact: []byte{0x00, 0x01, 0x23, 0x45}},
+ // odd length, terminator
+ {hex: []byte{15, 1, 12, 11, 8, 16 /*term*/}, compact: []byte{0x3f, 0x1c, 0xb8}},
+ // even length, terminator
+ {hex: []byte{0, 15, 1, 12, 11, 8, 16 /*term*/}, compact: []byte{0x20, 0x0f, 0x1c, 0xb8}},
+ }
+ for _, test := range tests {
+ if c := hexToCompact(test.hex); !bytes.Equal(c, test.compact) {
+ t.Errorf("hexToCompact(%x) -> %x, want %x", test.hex, c, test.compact)
+ }
+ if h := compactToHex(test.compact); !bytes.Equal(h, test.hex) {
+ t.Errorf("compactToHex(%x) -> %x, want %x", test.compact, h, test.hex)
+ }
+ }
+}
+
+func TestHexKeybytes(t *testing.T) {
+ tests := []struct{ key, hexIn, hexOut []byte }{
+ {key: []byte{}, hexIn: []byte{16}, hexOut: []byte{16}},
+ {key: []byte{}, hexIn: []byte{}, hexOut: []byte{16}},
+ {
+ key: []byte{0x12, 0x34, 0x56},
+ hexIn: []byte{1, 2, 3, 4, 5, 6, 16},
+ hexOut: []byte{1, 2, 3, 4, 5, 6, 16},
+ },
+ {
+ key: []byte{0x12, 0x34, 0x5},
+ hexIn: []byte{1, 2, 3, 4, 0, 5, 16},
+ hexOut: []byte{1, 2, 3, 4, 0, 5, 16},
+ },
+ {
+ key: []byte{0x12, 0x34, 0x56},
+ hexIn: []byte{1, 2, 3, 4, 5, 6},
+ hexOut: []byte{1, 2, 3, 4, 5, 6, 16},
+ },
+ }
+ for _, test := range tests {
+ if h := keybytesToHex(test.key); !bytes.Equal(h, test.hexOut) {
+ t.Errorf("keybytesToHex(%x) -> %x, want %x", test.key, h, test.hexOut)
+ }
+ if k := hexToKeyBytes(test.hexIn); !bytes.Equal(k, test.key) {
+ t.Errorf("hexToKeyBytes(%x) -> %x, want %x", test.hexIn, k, test.key)
+ }
+ }
+}
+
+func TestHexToCompactInPlace(t *testing.T) {
+ for i, key := range []string{
+ "00",
+ "060a040c0f000a090b040803010801010900080d090a0a0d0903000b10",
+ "10",
+ } {
+ hexBytes, _ := hex.DecodeString(key)
+ exp := hexToCompact(hexBytes)
+ sz := hexToCompactInPlace(hexBytes)
+ got := hexBytes[:sz]
+ if !bytes.Equal(exp, got) {
+ t.Fatalf("test %d: encoding err\ninp %v\ngot %x\nexp %x\n", i, key, got, exp)
+ }
+ }
+}
+
+func TestHexToCompactInPlaceRandom(t *testing.T) {
+ for i := 0; i < 10000; i++ {
+ l := rand.Intn(128)
+ key := make([]byte, l)
+ crand.Read(key)
+ hexBytes := keybytesToHex(key)
+ hexOrig := []byte(string(hexBytes))
+ exp := hexToCompact(hexBytes)
+ sz := hexToCompactInPlace(hexBytes)
+ got := hexBytes[:sz]
+
+ if !bytes.Equal(exp, got) {
+ t.Fatalf("encoding err \ncpt %x\nhex %x\ngot %x\nexp %x\n",
+ key, hexOrig, got, exp)
+ }
+ }
+}
+
+func BenchmarkHexToCompact(b *testing.B) {
+ testBytes := []byte{0, 15, 1, 12, 11, 8, 16 /*term*/}
+ for i := 0; i < b.N; i++ {
+ hexToCompact(testBytes)
+ }
+}
+
+func BenchmarkCompactToHex(b *testing.B) {
+ testBytes := []byte{0, 15, 1, 12, 11, 8, 16 /*term*/}
+ for i := 0; i < b.N; i++ {
+ compactToHex(testBytes)
+ }
+}
+
+func BenchmarkKeybytesToHex(b *testing.B) {
+ testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16}
+ for i := 0; i < b.N; i++ {
+ keybytesToHex(testBytes)
+ }
+}
+
+func BenchmarkHexToKeybytes(b *testing.B) {
+ testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16}
+ for i := 0; i < b.N; i++ {
+ hexToKeyBytes(testBytes)
+ }
+}
diff --git a/trie_by_cid/trie/errors.go b/trie_by_cid/trie/errors.go
index 3881487..afe344b 100644
--- a/trie_by_cid/trie/errors.go
+++ b/trie_by_cid/trie/errors.go
@@ -27,7 +27,7 @@ import (
// information necessary for retrieving the missing node.
type MissingNodeError struct {
Owner common.Hash // owner of the trie if it's 2-layered trie
- NodeHash []byte // hash of the missing node
+ NodeHash common.Hash // hash of the missing node
Path []byte // hex-encoded path to the missing node
err error // concrete error for missing trie node
}
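
With NodeHash now a common.Hash, callers can match and report the error directly; a minimal sketch (reportMissing is hypothetical and assumes the usual pointer-receiver Error implementation, with "errors" and logrus already imported elsewhere in this package):

func reportMissing(err error) {
	var missing *MissingNodeError
	if errors.As(err, &missing) {
		log.Warnf("missing trie node %x at path %x (owner %x)",
			missing.NodeHash, missing.Path, missing.Owner)
	}
}
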
diff --git a/trie_by_cid/trie/hasher.go b/trie_by_cid/trie/hasher.go
index caa80f0..e594d6d 100644
--- a/trie_by_cid/trie/hasher.go
+++ b/trie_by_cid/trie/hasher.go
@@ -21,7 +21,6 @@ import (
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/rlp"
- "github.com/ethereum/go-ethereum/trie"
"golang.org/x/crypto/sha3"
)
@@ -99,7 +98,7 @@ func (h *hasher) hashShortNodeChildren(n *shortNode) (collapsed, cached *shortNo
// Previously, we did copy this one. We don't seem to need to actually
// do that, since we don't overwrite/reuse keys
//cached.Key = common.CopyBytes(n.Key)
- collapsed.Key = trie.HexToCompact(n.Key)
+ collapsed.Key = hexToCompact(n.Key)
// Unless the child is a valuenode or hashnode, hash it
switch n.Val.(type) {
case *fullNode, *shortNode:
@@ -171,8 +170,8 @@ func (h *hasher) fullnodeToHash(n *fullNode, force bool) node {
//
// All node encoding must be done like this:
//
-// node.encode(h.encbuf)
-// enc := h.encodedBytes()
+// node.encode(h.encbuf)
+// enc := h.encodedBytes()
//
// This convention exists because node.encode can only be inlined/escape-analyzed when
// called on a concrete receiver type.
diff --git a/trie_by_cid/trie/iterator.go b/trie_by_cid/trie/iterator.go
index 55e7a96..de506eb 100644
--- a/trie_by_cid/trie/iterator.go
+++ b/trie_by_cid/trie/iterator.go
@@ -18,14 +18,26 @@ package trie
import (
"bytes"
+ "container/heap"
"errors"
+ "time"
+
+ "github.com/ethereum/go-ethereum/statediff/indexer/database/metrics"
"github.com/ethereum/go-ethereum/common"
- "github.com/ethereum/go-ethereum/trie"
+ "github.com/ethereum/go-ethereum/core/types"
+ gethtrie "github.com/ethereum/go-ethereum/trie"
)
-// NodeIterator is a re-export of the go-ethereum interface
-type NodeIterator = trie.NodeIterator
+// NodeIterator is an iterator to traverse the trie pre-order.
+type NodeIterator = gethtrie.NodeIterator
+
+// NodeResolver is used for looking up trie nodes before reaching into the real
+// persistent layer. This is not mandatory, but rather an optimization for cases
+// where trie nodes can be recovered from some external mechanism without reading
+// from disk. In those cases, this resolver allows short circuiting accesses and
+// returning them from memory.
+type NodeResolver = gethtrie.NodeResolver
// Iterator is a key-value trie iterator that traverses a Trie.
type Iterator struct {
@@ -82,7 +94,7 @@ type nodeIterator struct {
path []byte // Path to the current node
err error // Failure set in case of an internal error in the iterator
- resolver trie.NodeResolver // Optional intermediate resolver above the disk layer
+ resolver NodeResolver // optional node resolver for avoiding disk hits
}
// errIteratorEnd is stored in nodeIterator.err when iteration is done.
@@ -99,7 +111,7 @@ func (e seekError) Error() string {
}
func newNodeIterator(trie *Trie, start []byte) NodeIterator {
- if trie.Hash() == emptyRoot {
+ if trie.Hash() == types.EmptyRootHash {
return &nodeIterator{
trie: trie,
err: errIteratorEnd,
@@ -110,7 +122,7 @@ func newNodeIterator(trie *Trie, start []byte) NodeIterator {
return it
}
-func (it *nodeIterator) AddResolver(resolver trie.NodeResolver) {
+func (it *nodeIterator) AddResolver(resolver NodeResolver) {
it.resolver = resolver
}
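A resolver can be attached before iteration begins to short-circuit disk reads; a minimal sketch assuming a caller-maintained cache keyed by node path (the helper and cache are hypothetical, not part of this diff):

```go
// iterateWithCache serves nodes found in cache from memory; returning nil
// from the resolver falls through to the trie's underlying node reader.
func iterateWithCache(tr *Trie, cache map[string][]byte) NodeIterator {
	it := tr.NodeIterator(nil)
	it.AddResolver(func(owner common.Hash, path []byte, hash common.Hash) []byte {
		return cache[string(path)]
	})
	return it
}
```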
@@ -128,6 +140,14 @@ func (it *nodeIterator) Parent() common.Hash {
return it.stack[len(it.stack)-1].parent
}
+func (it *nodeIterator) ParentPath() []byte {
+ if len(it.stack) == 0 {
+ return []byte{}
+ }
+ pathlen := it.stack[len(it.stack)-1].pathlen
+ return it.path[:pathlen]
+}
+
func (it *nodeIterator) Leaf() bool {
return hasTerm(it.path)
}
@@ -241,7 +261,7 @@ func (it *nodeIterator) seek(prefix []byte) error {
func (it *nodeIterator) init() (*nodeIteratorState, error) {
root := it.trie.Hash()
state := &nodeIteratorState{node: it.trie.root, index: -1}
- if root != emptyRoot {
+ if root != types.EmptyRootHash {
state.hash = root
}
return state, state.resolve(it, nil)
@@ -320,7 +340,12 @@ func (it *nodeIterator) resolveHash(hash hashNode, path []byte) (node, error) {
}
}
}
- return it.trie.resolveHash(hash, path)
+ // Retrieve the specified node from the underlying node reader.
+ // it.trie.resolveAndTrack is not used because it tracks every loaded
+ // blob; that is unnecessary here, since nodes loaded by the iterator
+ // are never linked back into the trie, and tracking them could lead
+ // to out-of-memory issues.
+ return it.trie.reader.node(path, common.BytesToHash(hash))
}
func (it *nodeIterator) resolveBlob(hash hashNode, path []byte) ([]byte, error) {
@@ -329,7 +354,12 @@ func (it *nodeIterator) resolveBlob(hash hashNode, path []byte) ([]byte, error)
return blob, nil
}
}
- return it.trie.resolveBlob(hash, path)
+ // Retrieve the specified node from the underlying node reader.
+ // it.trie.resolveAndTrack is not used because it tracks every loaded
+ // blob; that is unnecessary here, since nodes loaded by the iterator
+ // are never linked back into the trie, and tracking them could lead
+ // to out-of-memory issues.
+ return it.trie.reader.nodeBlob(path, common.BytesToHash(hash))
}
func (st *nodeIteratorState) resolve(it *nodeIterator, path []byte) error {
@@ -455,3 +485,248 @@ func (it *nodeIterator) pop() {
it.stack[len(it.stack)-1] = nil
it.stack = it.stack[:len(it.stack)-1]
}
+
+func compareNodes(a, b NodeIterator) int {
+ if cmp := bytes.Compare(a.Path(), b.Path()); cmp != 0 {
+ return cmp
+ }
+ if a.Leaf() && !b.Leaf() {
+ return -1
+ } else if b.Leaf() && !a.Leaf() {
+ return 1
+ }
+ if cmp := bytes.Compare(a.Hash().Bytes(), b.Hash().Bytes()); cmp != 0 {
+ return cmp
+ }
+ if a.Leaf() && b.Leaf() {
+ return bytes.Compare(a.LeafBlob(), b.LeafBlob())
+ }
+ return 0
+}
+
+type differenceIterator struct {
+ a, b NodeIterator // Nodes returned are those in b - a.
+ eof bool // Indicates a has run out of elements
+ count int // Number of nodes scanned on either trie
+}
+
+// NewDifferenceIterator constructs a NodeIterator that iterates over elements in b that
+// are not in a. Returns the iterator, and a pointer to an integer recording the number
+// of nodes seen.
+func NewDifferenceIterator(a, b NodeIterator) (NodeIterator, *int) {
+ a.Next(true)
+ it := &differenceIterator{
+ a: a,
+ b: b,
+ }
+ return it, &it.count
+}
+
+func (it *differenceIterator) Hash() common.Hash {
+ return it.b.Hash()
+}
+
+func (it *differenceIterator) Parent() common.Hash {
+ return it.b.Parent()
+}
+
+func (it *differenceIterator) ParentPath() []byte {
+ return it.b.ParentPath()
+}
+
+func (it *differenceIterator) Leaf() bool {
+ return it.b.Leaf()
+}
+
+func (it *differenceIterator) LeafKey() []byte {
+ return it.b.LeafKey()
+}
+
+func (it *differenceIterator) LeafBlob() []byte {
+ return it.b.LeafBlob()
+}
+
+func (it *differenceIterator) LeafProof() [][]byte {
+ return it.b.LeafProof()
+}
+
+func (it *differenceIterator) Path() []byte {
+ return it.b.Path()
+}
+
+func (it *differenceIterator) NodeBlob() []byte {
+ return it.b.NodeBlob()
+}
+
+func (it *differenceIterator) AddResolver(resolver NodeResolver) {
+ panic("not implemented")
+}
+
+func (it *differenceIterator) Next(bool) bool {
+ defer metrics.UpdateDuration(time.Now(), metrics.IndexerMetrics.DifferenceIteratorNextTimer)
+ // Invariants:
+ // - We always advance at least one element in b.
+ // - At the start of this function, a's path is lexically greater than b's.
+ if !it.b.Next(true) {
+ return false
+ }
+ it.count++
+
+ if it.eof {
+ // a has reached eof, so we just return all elements from b
+ return true
+ }
+
+ for {
+ switch compareNodes(it.a, it.b) {
+ case -1:
+ // b jumped past a; advance a
+ if !it.a.Next(true) {
+ it.eof = true
+ return true
+ }
+ it.count++
+ case 1:
+ // b is before a
+ return true
+ case 0:
+ // a and b are identical; skip this whole subtree if the nodes have hashes
+ hasHash := it.a.Hash() == common.Hash{}
+ if !it.b.Next(hasHash) {
+ return false
+ }
+ it.count++
+ if !it.a.Next(hasHash) {
+ it.eof = true
+ return true
+ }
+ it.count++
+ }
+ }
+}
+
+func (it *differenceIterator) Error() error {
+ if err := it.a.Error(); err != nil {
+ return err
+ }
+ return it.b.Error()
+}
+
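A minimal usage sketch for the difference iterator (hypothetical helper; TestDifferenceIterator later in this diff exercises the same flow):

```go
// diffLeaves returns the leaves present in trie b but absent or changed in a.
func diffLeaves(a, b *Trie) map[string][]byte {
	di, _ := NewDifferenceIterator(a.NodeIterator(nil), b.NodeIterator(nil))
	out := make(map[string][]byte)
	for it := NewIterator(di); it.Next(); {
		out[string(it.Key)] = it.Value
	}
	return out
}
```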
+type nodeIteratorHeap []NodeIterator
+
+func (h nodeIteratorHeap) Len() int { return len(h) }
+func (h nodeIteratorHeap) Less(i, j int) bool { return compareNodes(h[i], h[j]) < 0 }
+func (h nodeIteratorHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
+func (h *nodeIteratorHeap) Push(x interface{}) { *h = append(*h, x.(NodeIterator)) }
+func (h *nodeIteratorHeap) Pop() interface{} {
+ n := len(*h)
+ x := (*h)[n-1]
+ *h = (*h)[0 : n-1]
+ return x
+}
+
+type unionIterator struct {
+ items *nodeIteratorHeap // Nodes returned are the union of the ones in these iterators
+ count int // Number of nodes scanned across all tries
+}
+
+// NewUnionIterator constructs a NodeIterator that iterates over elements in the union
+// of the provided NodeIterators. Returns the iterator, and a pointer to an integer
+// recording the number of nodes visited.
+func NewUnionIterator(iters []NodeIterator) (NodeIterator, *int) {
+ h := make(nodeIteratorHeap, len(iters))
+ copy(h, iters)
+ heap.Init(&h)
+
+ ui := &unionIterator{items: &h}
+ return ui, &ui.count
+}
+
+func (it *unionIterator) Hash() common.Hash {
+ return (*it.items)[0].Hash()
+}
+
+func (it *unionIterator) Parent() common.Hash {
+ return (*it.items)[0].Parent()
+}
+
+func (it *unionIterator) ParentPath() []byte {
+ return (*it.items)[0].ParentPath()
+}
+
+func (it *unionIterator) Leaf() bool {
+ return (*it.items)[0].Leaf()
+}
+
+func (it *unionIterator) LeafKey() []byte {
+ return (*it.items)[0].LeafKey()
+}
+
+func (it *unionIterator) LeafBlob() []byte {
+ return (*it.items)[0].LeafBlob()
+}
+
+func (it *unionIterator) LeafProof() [][]byte {
+ return (*it.items)[0].LeafProof()
+}
+
+func (it *unionIterator) Path() []byte {
+ return (*it.items)[0].Path()
+}
+
+func (it *unionIterator) NodeBlob() []byte {
+ return (*it.items)[0].NodeBlob()
+}
+
+func (it *unionIterator) AddResolver(resolver NodeResolver) {
+ panic("not implemented")
+}
+
+// Next returns the next node in the union of tries being iterated over.
+//
+// It does this by maintaining a heap of iterators, sorted by the iteration
+// order of their next elements, with one entry for each source trie. Each
+// time Next() is called, it takes the least element from the heap to return,
+// advancing any other iterators that also point to that same element. These
+// iterators are called with descend=false, since we know that any nodes under
+// these nodes will also be duplicates, found in the currently selected iterator.
+// Whenever an iterator is advanced, it is pushed back into the heap if it still
+// has elements remaining.
+//
+// In the case that descend=false - e.g. we're asked to ignore all subnodes of the
+// current node - we also advance any iterators in the heap that have the current
+// path as a prefix.
+func (it *unionIterator) Next(descend bool) bool {
+ if len(*it.items) == 0 {
+ return false
+ }
+
+ // Get the next key from the union
+ least := heap.Pop(it.items).(NodeIterator)
+
+ // Skip over other nodes as long as they're identical, or, if we're not descending, as
+ // long as they have the same prefix as the current node.
+ for len(*it.items) > 0 && ((!descend && bytes.HasPrefix((*it.items)[0].Path(), least.Path())) || compareNodes(least, (*it.items)[0]) == 0) {
+ skipped := heap.Pop(it.items).(NodeIterator)
+ // Skip the whole subtree if the nodes have hashes; otherwise just skip this node
+ if skipped.Next(skipped.Hash() == common.Hash{}) {
+ it.count++
+ // If there are more elements, push the iterator back on the heap
+ heap.Push(it.items, skipped)
+ }
+ }
+ if least.Next(descend) {
+ it.count++
+ heap.Push(it.items, least)
+ }
+ return len(*it.items) > 0
+}
+
+func (it *unionIterator) Error() error {
+ for i := 0; i < len(*it.items); i++ {
+ if err := (*it.items)[i].Error(); err != nil {
+ return err
+ }
+ }
+ return nil
+}
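And a matching sketch for the union iterator (hypothetical helper; as TestUnionIterator later in this diff shows, identical nodes are emitted once, while keys whose values differ across tries appear once per distinct value):

```go
// unionKeys walks the union of the given tries in key order.
func unionKeys(tries []*Trie) [][]byte {
	iters := make([]NodeIterator, len(tries))
	for i, tr := range tries {
		iters[i] = tr.NodeIterator(nil)
	}
	ui, _ := NewUnionIterator(iters)
	var keys [][]byte
	for it := NewIterator(ui); it.Next(); {
		keys = append(keys, it.Key)
	}
	return keys
}
```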
diff --git a/trie_by_cid/trie/iterator_test.go b/trie_by_cid/trie/iterator_test.go
index f7f0103..8b606ef 100644
--- a/trie_by_cid/trie/iterator_test.go
+++ b/trie_by_cid/trie/iterator_test.go
@@ -14,28 +14,24 @@
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-package trie_test
+package trie
import (
"bytes"
- "context"
+ "encoding/binary"
"fmt"
+ "math/rand"
"testing"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
- "github.com/ethereum/go-ethereum/statediff/indexer/database/sql/postgres"
+ "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/ethdb/memorydb"
geth_trie "github.com/ethereum/go-ethereum/trie"
-
- "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie"
)
var (
- dbConfig, _ = postgres.DefaultConfig.WithEnv()
- trieConfig = trie.Config{Cache: 256}
- ctx = context.Background()
-
- testdata0 = []kvs{
+ packableTestData = []kvsi{
{"one", 1},
{"two", 2},
{"three", 3},
@@ -43,20 +39,10 @@ var (
{"five", 5},
{"ten", 10},
}
- testdata1 = []kvs{
- {"barb", 0},
- {"bard", 1},
- {"bars", 2},
- {"bar", 3},
- {"fab", 4},
- {"food", 5},
- {"foos", 6},
- {"foo", 7},
- }
)
func TestEmptyIterator(t *testing.T) {
- trie := trie.NewEmpty(trie.NewDatabase(rawdb.NewMemoryDatabase()))
+ trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
iter := trie.NodeIterator(nil)
seen := make(map[string]struct{})
@@ -69,63 +55,163 @@ func TestEmptyIterator(t *testing.T) {
}
func TestIterator(t *testing.T) {
- edb := rawdb.NewMemoryDatabase()
- db := geth_trie.NewDatabase(edb)
- origtrie := geth_trie.NewEmpty(db)
- all, err := updateTrie(origtrie, testdata0)
- if err != nil {
- t.Fatal(err)
+ db := NewDatabase(rawdb.NewMemoryDatabase())
+ trie := NewEmpty(db)
+ vals := []struct{ k, v string }{
+ {"do", "verb"},
+ {"ether", "wookiedoo"},
+ {"horse", "stallion"},
+ {"shaman", "horse"},
+ {"doge", "coin"},
+ {"dog", "puppy"},
+ {"somethingveryoddindeedthis is", "myothernodedata"},
}
- // commit and index data
- root := commitTrie(t, db, origtrie)
- tr := indexTrie(t, edb, root)
+ all := make(map[string]string)
+ for _, val := range vals {
+ all[val.k] = val.v
+ trie.Update([]byte(val.k), []byte(val.v))
+ }
+ root, nodes := trie.Commit(false)
+ db.Update(NewWithNodeSet(nodes))
- found := make(map[string]int64)
- it := trie.NewIterator(tr.NodeIterator(nil))
+ trie, _ = New(TrieID(root), db, StateTrieCodec)
+ found := make(map[string]string)
+ it := NewIterator(trie.NodeIterator(nil))
for it.Next() {
- found[string(it.Key)] = unpackValue(it.Value)
+ found[string(it.Key)] = string(it.Value)
}
- if len(found) != len(all) {
- t.Errorf("number of iterated values do not match: want %d, found %d", len(all), len(found))
- }
- for k, kv := range all {
- if found[k] != kv.v {
- t.Errorf("iterator value mismatch for %s: got %q want %q", k, found[k], kv.v)
+ for k, v := range all {
+ if found[k] != v {
+ t.Errorf("iterator value mismatch for %s: got %q want %q", k, found[k], v)
}
}
}
-func TestIteratorSeek(t *testing.T) {
- edb := rawdb.NewMemoryDatabase()
- db := geth_trie.NewDatabase(edb)
- orig := geth_trie.NewEmpty(geth_trie.NewDatabase(rawdb.NewMemoryDatabase()))
- if _, err := updateTrie(orig, testdata1); err != nil {
- t.Fatal(err)
+type kv struct {
+ k, v []byte
+ t bool
+}
+
+func TestIteratorLargeData(t *testing.T) {
+ trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ vals := make(map[string]*kv)
+
+ for i := byte(0); i < 255; i++ {
+ value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
+ value2 := &kv{common.LeftPadBytes([]byte{10, i}, 32), []byte{i}, false}
+ trie.Update(value.k, value.v)
+ trie.Update(value2.k, value2.v)
+ vals[string(value.k)] = value
+ vals[string(value2.k)] = value2
+ }
+
+ it := NewIterator(trie.NodeIterator(nil))
+ for it.Next() {
+ vals[string(it.Key)].t = true
+ }
+
+ var untouched []*kv
+ for _, value := range vals {
+ if !value.t {
+ untouched = append(untouched, value)
+ }
+ }
+
+ if len(untouched) > 0 {
+ t.Errorf("Missed %d nodes", len(untouched))
+ for _, value := range untouched {
+ t.Error(value)
+ }
+ }
+}
+
+// Tests that the node iterator indeed walks over the entire database contents.
+func TestNodeIteratorCoverage(t *testing.T) {
+ // Create some arbitrary test trie to iterate
+ db, trie, _ := makeTestTrie(t)
+
+ // Gather all the node hashes found by the iterator
+ hashes := make(map[common.Hash]struct{})
+ for it := trie.NodeIterator(nil); it.Next(true); {
+ if it.Hash() != (common.Hash{}) {
+ hashes[it.Hash()] = struct{}{}
+ }
+ }
+ // Cross check the hashes and the database itself
+ for hash := range hashes {
+ if _, err := db.Node(hash, StateTrieCodec); err != nil {
+ t.Errorf("failed to retrieve reported node %x: %v", hash, err)
+ }
+ }
+ for hash, obj := range db.dirties {
+ if obj != nil && hash != (common.Hash{}) {
+ if _, ok := hashes[hash]; !ok {
+ t.Errorf("state entry not reported %x", hash)
+ }
+ }
+ }
+ it := db.diskdb.NewIterator(nil, nil)
+ for it.Next() {
+ key := it.Key()
+ if _, ok := hashes[common.BytesToHash(key)]; !ok {
+ t.Errorf("state entry not reported %x", key)
+ }
+ }
+ it.Release()
+}
+
+type kvs struct{ k, v string }
+
+var testdata1 = []kvs{
+ {"barb", "ba"},
+ {"bard", "bc"},
+ {"bars", "bb"},
+ {"bar", "b"},
+ {"fab", "z"},
+ {"food", "ab"},
+ {"foos", "aa"},
+ {"foo", "a"},
+}
+
+var testdata2 = []kvs{
+ {"aardvark", "c"},
+ {"bar", "b"},
+ {"barb", "bd"},
+ {"bars", "be"},
+ {"fab", "z"},
+ {"foo", "a"},
+ {"foos", "aa"},
+ {"food", "ab"},
+ {"jars", "d"},
+}
+
+func TestIteratorSeek(t *testing.T) {
+ trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ for _, val := range testdata1 {
+ trie.Update([]byte(val.k), []byte(val.v))
}
- root := commitTrie(t, db, orig)
- tr := indexTrie(t, edb, root)
// Seek to the middle.
- it := trie.NewIterator(tr.NodeIterator([]byte("fab")))
+ it := NewIterator(trie.NodeIterator([]byte("fab")))
if err := checkIteratorOrder(testdata1[4:], it); err != nil {
t.Fatal(err)
}
// Seek to a non-existent key.
- it = trie.NewIterator(tr.NodeIterator([]byte("barc")))
+ it = NewIterator(trie.NodeIterator([]byte("barc")))
if err := checkIteratorOrder(testdata1[1:], it); err != nil {
t.Fatal(err)
}
// Seek beyond the end.
- it = trie.NewIterator(tr.NodeIterator([]byte("z")))
+ it = NewIterator(trie.NodeIterator([]byte("z")))
if err := checkIteratorOrder(nil, it); err != nil {
t.Fatal(err)
}
}
-func checkIteratorOrder(want []kvs, it *trie.Iterator) error {
+func checkIteratorOrder(want []kvs, it *Iterator) error {
for it.Next() {
if len(want) == 0 {
return fmt.Errorf("didn't expect any more values, got key %q", it.Key)
@@ -141,11 +227,366 @@ func checkIteratorOrder(want []kvs, it *trie.Iterator) error {
return nil
}
+func TestDifferenceIterator(t *testing.T) {
+ dba := NewDatabase(rawdb.NewMemoryDatabase())
+ triea := NewEmpty(dba)
+ for _, val := range testdata1 {
+ triea.Update([]byte(val.k), []byte(val.v))
+ }
+ rootA, nodesA := triea.Commit(false)
+ dba.Update(NewWithNodeSet(nodesA))
+ triea, _ = New(TrieID(rootA), dba, StateTrieCodec)
+
+ dbb := NewDatabase(rawdb.NewMemoryDatabase())
+ trieb := NewEmpty(dbb)
+ for _, val := range testdata2 {
+ trieb.Update([]byte(val.k), []byte(val.v))
+ }
+ rootB, nodesB := trieb.Commit(false)
+ dbb.Update(NewWithNodeSet(nodesB))
+ trieb, _ = New(TrieID(rootB), dbb, StateTrieCodec)
+
+ found := make(map[string]string)
+ di, _ := NewDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil))
+ it := NewIterator(di)
+ for it.Next() {
+ found[string(it.Key)] = string(it.Value)
+ }
+
+ all := []struct{ k, v string }{
+ {"aardvark", "c"},
+ {"barb", "bd"},
+ {"bars", "be"},
+ {"jars", "d"},
+ }
+ for _, item := range all {
+ if found[item.k] != item.v {
+ t.Errorf("iterator value mismatch for %s: got %v want %v", item.k, found[item.k], item.v)
+ }
+ }
+ if len(found) != len(all) {
+ t.Errorf("iterator count mismatch: got %d values, want %d", len(found), len(all))
+ }
+}
+
+func TestUnionIterator(t *testing.T) {
+ dba := NewDatabase(rawdb.NewMemoryDatabase())
+ triea := NewEmpty(dba)
+ for _, val := range testdata1 {
+ triea.Update([]byte(val.k), []byte(val.v))
+ }
+ rootA, nodesA := triea.Commit(false)
+ dba.Update(NewWithNodeSet(nodesA))
+ triea, _ = New(TrieID(rootA), dba, StateTrieCodec)
+
+ dbb := NewDatabase(rawdb.NewMemoryDatabase())
+ trieb := NewEmpty(dbb)
+ for _, val := range testdata2 {
+ trieb.Update([]byte(val.k), []byte(val.v))
+ }
+ rootB, nodesB := trieb.Commit(false)
+ dbb.Update(NewWithNodeSet(nodesB))
+ trieb, _ = New(TrieID(rootB), dbb, StateTrieCodec)
+
+ di, _ := NewUnionIterator([]NodeIterator{triea.NodeIterator(nil), trieb.NodeIterator(nil)})
+ it := NewIterator(di)
+
+ all := []struct{ k, v string }{
+ {"aardvark", "c"},
+ {"barb", "ba"},
+ {"barb", "bd"},
+ {"bard", "bc"},
+ {"bars", "bb"},
+ {"bars", "be"},
+ {"bar", "b"},
+ {"fab", "z"},
+ {"food", "ab"},
+ {"foos", "aa"},
+ {"foo", "a"},
+ {"jars", "d"},
+ }
+
+ for i, kv := range all {
+ if !it.Next() {
+ t.Errorf("Iterator ends prematurely at element %d", i)
+ }
+ if kv.k != string(it.Key) {
+ t.Errorf("iterator value mismatch for element %d: got key %s want %s", i, it.Key, kv.k)
+ }
+ if kv.v != string(it.Value) {
+ t.Errorf("iterator value mismatch for element %d: got value %s want %s", i, it.Value, kv.v)
+ }
+ }
+ if it.Next() {
+ t.Errorf("Iterator returned extra values.")
+ }
+}
+
+func TestIteratorNoDups(t *testing.T) {
+ tr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ for _, val := range testdata1 {
+ tr.Update([]byte(val.k), []byte(val.v))
+ }
+ checkIteratorNoDups(t, tr.NodeIterator(nil), nil)
+}
+
+// This test checks that nodeIterator.Next can be retried after inserting missing trie nodes.
+func TestIteratorContinueAfterErrorMemonly(t *testing.T) { testIteratorContinueAfterError(t, true) }
+
+func testIteratorContinueAfterError(t *testing.T, memonly bool) {
+ diskdb := rawdb.NewMemoryDatabase()
+ triedb := NewDatabase(diskdb)
+
+ tr := NewEmpty(triedb)
+ for _, val := range testdata1 {
+ tr.Update([]byte(val.k), []byte(val.v))
+ }
+ _, nodes := tr.Commit(false)
+ triedb.Update(NewWithNodeSet(nodes))
+ // if !memonly {
+ // triedb.Commit(tr.Hash(), false)
+ // }
+ wantNodeCount := checkIteratorNoDups(t, tr.NodeIterator(nil), nil)
+
+ var (
+ diskKeys [][]byte
+ memKeys []common.Hash
+ )
+ if memonly {
+ memKeys = triedb.Nodes()
+ } else {
+ it := diskdb.NewIterator(nil, nil)
+ for it.Next() {
+ diskKeys = append(diskKeys, it.Key())
+ }
+ it.Release()
+ }
+ for i := 0; i < 20; i++ {
+ // Create trie that will load all nodes from DB.
+ tr, _ := New(TrieID(tr.Hash()), triedb, StateTrieCodec)
+
+ // Remove a random node from the database. It can't be the root node
+ // because that one is already loaded.
+ var (
+ rkey common.Hash
+ rval []byte
+ robj *cachedNode
+ )
+ for {
+ if memonly {
+ rkey = memKeys[rand.Intn(len(memKeys))]
+ } else {
+ copy(rkey[:], diskKeys[rand.Intn(len(diskKeys))])
+ }
+ if rkey != tr.Hash() {
+ break
+ }
+ }
+ if memonly {
+ robj = triedb.dirties[rkey]
+ delete(triedb.dirties, rkey)
+ } else {
+ rval, _ = diskdb.Get(rkey[:])
+ diskdb.Delete(rkey[:])
+ }
+ // Iterate until the error is hit.
+ seen := make(map[string]bool)
+ it := tr.NodeIterator(nil)
+ checkIteratorNoDups(t, it, seen)
+ missing, ok := it.Error().(*MissingNodeError)
+ if !ok || missing.NodeHash != rkey {
+ t.Fatal("didn't hit missing node, got", it.Error())
+ }
+
+ // Add the node back and continue iteration.
+ if memonly {
+ triedb.dirties[rkey] = robj
+ } else {
+ diskdb.Put(rkey[:], rval)
+ }
+ checkIteratorNoDups(t, it, seen)
+ if it.Error() != nil {
+ t.Fatal("unexpected error", it.Error())
+ }
+ if len(seen) != wantNodeCount {
+ t.Fatal("wrong node iteration count, got", len(seen), "want", wantNodeCount)
+ }
+ }
+}
+
+// Similar to the test above, this one checks that failure to create nodeIterator at a
+// certain key prefix behaves correctly when Next is called. The expectation is that Next
+// should retry seeking before returning true for the first time.
+func TestIteratorContinueAfterSeekErrorMemonly(t *testing.T) {
+ testIteratorContinueAfterSeekError(t, true)
+}
+
+func testIteratorContinueAfterSeekError(t *testing.T, memonly bool) {
+ // Commit test trie to db, then remove the node containing "bars".
+ diskdb := rawdb.NewMemoryDatabase()
+ triedb := NewDatabase(diskdb)
+
+ ctr := NewEmpty(triedb)
+ for _, val := range testdata1 {
+ ctr.Update([]byte(val.k), []byte(val.v))
+ }
+ root, nodes := ctr.Commit(false)
+ triedb.Update(NewWithNodeSet(nodes))
+ // if !memonly {
+ // triedb.Commit(root, false)
+ // }
+ barNodeHash := common.HexToHash("05041990364eb72fcb1127652ce40d8bab765f2bfe53225b1170d276cc101c2e")
+ var (
+ barNodeBlob []byte
+ barNodeObj *cachedNode
+ )
+ if memonly {
+ barNodeObj = triedb.dirties[barNodeHash]
+ delete(triedb.dirties, barNodeHash)
+ } else {
+ barNodeBlob, _ = diskdb.Get(barNodeHash[:])
+ diskdb.Delete(barNodeHash[:])
+ }
+ // Create a new iterator that seeks to "bars". Seeking can't proceed because
+ // the node is missing.
+ tr, _ := New(TrieID(root), triedb, StateTrieCodec)
+ it := tr.NodeIterator([]byte("bars"))
+ missing, ok := it.Error().(*MissingNodeError)
+ if !ok {
+ t.Fatal("want MissingNodeError, got", it.Error())
+ } else if missing.NodeHash != barNodeHash {
+ t.Fatal("wrong node missing")
+ }
+ // Reinsert the missing node.
+ if memonly {
+ triedb.dirties[barNodeHash] = barNodeObj
+ } else {
+ diskdb.Put(barNodeHash[:], barNodeBlob)
+ }
+ // Check that iteration produces the right set of values.
+ if err := checkIteratorOrder(testdata1[2:], NewIterator(it)); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func checkIteratorNoDups(t *testing.T, it NodeIterator, seen map[string]bool) int {
+ if seen == nil {
+ seen = make(map[string]bool)
+ }
+ for it.Next(true) {
+ if seen[string(it.Path())] {
+ t.Fatalf("iterator visited node path %x twice", it.Path())
+ }
+ seen[string(it.Path())] = true
+ }
+ return len(seen)
+}
+
+type loggingTrieDb struct {
+ *Database
+ getCount uint64
+}
+
+// GetReader retrieves a node reader belonging to the given state root.
+func (db *loggingTrieDb) GetReader(root common.Hash, codec uint64) Reader {
+ return &loggingNodeReader{db, codec}
+}
+
+// loggingNodeReader is a Reader implementation that counts every node read.
+type loggingNodeReader struct {
+ db *loggingTrieDb
+ codec uint64
+}
+
+// Node retrieves the trie node with the given node hash.
+func (reader *loggingNodeReader) Node(owner common.Hash, path []byte, hash common.Hash) (node, error) {
+ blob, err := reader.NodeBlob(owner, path, hash)
+ if err != nil {
+ return nil, err
+ }
+ return decodeNodeUnsafe(hash[:], blob)
+}
+
+// NodeBlob retrieves the RLP-encoded trie node blob with the given node hash.
+func (reader *loggingNodeReader) NodeBlob(_ common.Hash, _ []byte, hash common.Hash) ([]byte, error) {
+ reader.db.getCount++
+ return reader.db.Node(hash, reader.codec)
+}
+
+func newLoggingStateTrie(id *ID, db *Database, codec uint64) (*StateTrie, *loggingTrieDb, error) {
+ logdb := &loggingTrieDb{Database: db}
+ trie, err := New(id, logdb, codec)
+ if err != nil {
+ return nil, nil, err
+ }
+ return &StateTrie{trie: *trie, preimages: db.preimages}, logdb, nil
+}
+
+// makeLargeTestTrie creates a sample test trie
+func makeLargeTestTrie(t testing.TB) (*Database, *StateTrie, *loggingTrieDb) {
+ // Create an empty trie
+ triedb := NewDatabase(rawdb.NewDatabase(memorydb.New()))
+ trie, logDb, err := newLoggingStateTrie(TrieID(common.Hash{}), triedb, StateTrieCodec)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Fill it with some arbitrary data
+ for i := 0; i < 10000; i++ {
+ key := make([]byte, 32)
+ val := make([]byte, 32)
+ binary.BigEndian.PutUint64(key, uint64(i))
+ binary.BigEndian.PutUint64(val, uint64(i))
+ key = crypto.Keccak256(key)
+ val = crypto.Keccak256(val)
+ trie.Update(key, val)
+ }
+ _, nodes := trie.Commit(false)
+ triedb.Update(NewWithNodeSet(nodes))
+ // Return the generated trie
+ return triedb, trie, logDb
+}
+
+// Tests that the node iterator indeed walks over the entire database contents.
+func TestNodeIteratorLargeTrie(t *testing.T) {
+ // Create some arbitrary test trie to iterate
+ _, trie, logDb := makeLargeTestTrie(t)
+ // Do a seek operation
+ trie.NodeIterator(common.FromHex("0x77667766776677766778855885885885"))
+ // master: 24 get operations
+ // this pr: 5 get operations
+ if have, want := logDb.getCount, uint64(5); have != want {
+ t.Fatalf("Wrong number of lookups during seek, have %d want %d", have, want)
+ }
+}
+
func TestIteratorNodeBlob(t *testing.T) {
edb := rawdb.NewMemoryDatabase()
db := geth_trie.NewDatabase(edb)
orig := geth_trie.NewEmpty(geth_trie.NewDatabase(rawdb.NewMemoryDatabase()))
- if _, err := updateTrie(orig, testdata1); err != nil {
+ if _, err := updateTrie(orig, packableTestData); err != nil {
t.Fatal(err)
}
root := commitTrie(t, db, orig)
@@ -167,7 +608,7 @@ func TestIteratorNodeBlob(t *testing.T) {
for dbIter.Next() {
got, present := found[common.BytesToHash(dbIter.Key())]
if !present {
- t.Fatalf("Missing trie node %v", dbIter.Key())
+ t.Fatalf("Miss trie node %v", dbIter.Key())
}
if !bytes.Equal(got, dbIter.Value()) {
t.Fatalf("Unexpected trie node want %v got %v", dbIter.Value(), got)
@@ -175,6 +616,6 @@ func TestIteratorNodeBlob(t *testing.T) {
count += 1
}
if count != len(found) {
- t.Fatal("Find extra trie node via iterator")
+ t.Fatalf("Wrong number of trie nodes found, want %d, got %d", len(found), count)
}
}
diff --git a/trie_by_cid/trie/node.go b/trie_by_cid/trie/node.go
index c995099..6ce6551 100644
--- a/trie_by_cid/trie/node.go
+++ b/trie_by_cid/trie/node.go
@@ -23,7 +23,6 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/rlp"
- "github.com/ethereum/go-ethereum/trie"
)
var indices = []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b", "c", "d", "e", "f", "[17]"}
@@ -157,7 +156,7 @@ func decodeShort(hash, elems []byte) (node, error) {
return nil, err
}
flag := nodeFlag{hash: hash}
- key := trie.CompactToHex(kbuf)
+ key := compactToHex(kbuf)
if hasTerm(key) {
// value node
val, _, err := rlp.SplitString(rest)
diff --git a/trie_by_cid/trie/node_enc.go b/trie_by_cid/trie/node_enc.go
index 2d26350..cade35b 100644
--- a/trie_by_cid/trie/node_enc.go
+++ b/trie_by_cid/trie/node_enc.go
@@ -58,3 +58,30 @@ func (n hashNode) encode(w rlp.EncoderBuffer) {
func (n valueNode) encode(w rlp.EncoderBuffer) {
w.WriteBytes(n)
}
+
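+// encode writes the RLP list encoding of the raw full node's 17 children to w,
+// encoding nil children as empty strings.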
+func (n rawFullNode) encode(w rlp.EncoderBuffer) {
+ offset := w.List()
+ for _, c := range n {
+ if c != nil {
+ c.encode(w)
+ } else {
+ w.Write(rlp.EmptyString)
+ }
+ }
+ w.ListEnd(offset)
+}
+
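+// encode writes the raw short node as a two-item RLP list of key and value,
+// encoding a nil value as an empty string.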
+func (n *rawShortNode) encode(w rlp.EncoderBuffer) {
+ offset := w.List()
+ w.WriteBytes(n.Key)
+ if n.Val != nil {
+ n.Val.encode(w)
+ } else {
+ w.Write(rlp.EmptyString)
+ }
+ w.ListEnd(offset)
+}
+
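+// encode writes the pre-encoded RLP blob verbatim, since a rawNode already
+// holds encoded bytes.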
+func (n rawNode) encode(w rlp.EncoderBuffer) {
+ w.Write(n)
+}
diff --git a/trie_by_cid/trie/node_test.go b/trie_by_cid/trie/node_test.go
index ac1d8fb..9b8b337 100644
--- a/trie_by_cid/trie/node_test.go
+++ b/trie_by_cid/trie/node_test.go
@@ -20,6 +20,7 @@ import (
"bytes"
"testing"
+ "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/rlp"
)
@@ -92,3 +93,123 @@ func TestDecodeFullNode(t *testing.T) {
t.Fatalf("decode full node err: %v", err)
}
}
+
+// goos: darwin
+// goarch: arm64
+// pkg: github.com/ethereum/go-ethereum/trie
+// BenchmarkEncodeShortNode
+// BenchmarkEncodeShortNode-8 16878850 70.81 ns/op 48 B/op 1 allocs/op
+func BenchmarkEncodeShortNode(b *testing.B) {
+ node := &shortNode{
+ Key: []byte{0x1, 0x2},
+ Val: hashNode(randBytes(32)),
+ }
+ b.ResetTimer()
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ nodeToBytes(node)
+ }
+}
+
+// goos: darwin
+// goarch: arm64
+// pkg: github.com/ethereum/go-ethereum/trie
+// BenchmarkEncodeFullNode
+// BenchmarkEncodeFullNode-8 4323273 284.4 ns/op 576 B/op 1 allocs/op
+func BenchmarkEncodeFullNode(b *testing.B) {
+ node := &fullNode{}
+ for i := 0; i < 16; i++ {
+ node.Children[i] = hashNode(randBytes(32))
+ }
+ b.ResetTimer()
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ nodeToBytes(node)
+ }
+}
+
+// goos: darwin
+// goarch: arm64
+// pkg: github.com/ethereum/go-ethereum/trie
+// BenchmarkDecodeShortNode
+// BenchmarkDecodeShortNode-8 7925638 151.0 ns/op 157 B/op 4 allocs/op
+func BenchmarkDecodeShortNode(b *testing.B) {
+ node := &shortNode{
+ Key: []byte{0x1, 0x2},
+ Val: hashNode(randBytes(32)),
+ }
+ blob := nodeToBytes(node)
+ hash := crypto.Keccak256(blob)
+
+ b.ResetTimer()
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ mustDecodeNode(hash, blob)
+ }
+}
+
+// goos: darwin
+// goarch: arm64
+// pkg: github.com/ethereum/go-ethereum/trie
+// BenchmarkDecodeShortNodeUnsafe
+// BenchmarkDecodeShortNodeUnsafe-8 9027476 128.6 ns/op 109 B/op 3 allocs/op
+func BenchmarkDecodeShortNodeUnsafe(b *testing.B) {
+ node := &shortNode{
+ Key: []byte{0x1, 0x2},
+ Val: hashNode(randBytes(32)),
+ }
+ blob := nodeToBytes(node)
+ hash := crypto.Keccak256(blob)
+
+ b.ResetTimer()
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ mustDecodeNodeUnsafe(hash, blob)
+ }
+}
+
+// goos: darwin
+// goarch: arm64
+// pkg: github.com/ethereum/go-ethereum/trie
+// BenchmarkDecodeFullNode
+// BenchmarkDecodeFullNode-8 1597462 761.9 ns/op 1280 B/op 18 allocs/op
+func BenchmarkDecodeFullNode(b *testing.B) {
+ node := &fullNode{}
+ for i := 0; i < 16; i++ {
+ node.Children[i] = hashNode(randBytes(32))
+ }
+ blob := nodeToBytes(node)
+ hash := crypto.Keccak256(blob)
+
+ b.ResetTimer()
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ mustDecodeNode(hash, blob)
+ }
+}
+
+// goos: darwin
+// goarch: arm64
+// pkg: github.com/ethereum/go-ethereum/trie
+// BenchmarkDecodeFullNodeUnsafe
+// BenchmarkDecodeFullNodeUnsafe-8 1789070 687.1 ns/op 704 B/op 17 allocs/op
+func BenchmarkDecodeFullNodeUnsafe(b *testing.B) {
+ node := &fullNode{}
+ for i := 0; i < 16; i++ {
+ node.Children[i] = hashNode(randBytes(32))
+ }
+ blob := nodeToBytes(node)
+ hash := crypto.Keccak256(blob)
+
+ b.ResetTimer()
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ mustDecodeNodeUnsafe(hash, blob)
+ }
+}
diff --git a/trie_by_cid/trie/nodeset.go b/trie_by_cid/trie/nodeset.go
new file mode 100644
index 0000000..99e4a80
--- /dev/null
+++ b/trie_by_cid/trie/nodeset.go
@@ -0,0 +1,218 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package trie
+
+import (
+ "fmt"
+ "reflect"
+ "sort"
+ "strings"
+
+ "github.com/ethereum/go-ethereum/common"
+)
+
+// memoryNode is all the information we know about a single cached trie node
+// in memory.
+type memoryNode struct {
+ hash common.Hash // Node hash, computed by hashing rlp value, empty for deleted nodes
+ size uint16 // Byte size of the useful cached data, 0 for deleted nodes
+ node node // Cached collapsed trie node, or raw rlp data, nil for deleted nodes
+}
+
+// memoryNodeSize is the raw size of a memoryNode data structure without any
+// node data included. It's an approximate size, but is far better than not
+// accounting for the struct overhead at all.
+// nolint:unused
+var memoryNodeSize = int(reflect.TypeOf(memoryNode{}).Size())
+
+// memorySize returns the total memory size used by this node.
+// nolint:unused
+func (n *memoryNode) memorySize(pathlen int) int {
+ return int(n.size) + memoryNodeSize + pathlen
+}
+
+// rlp returns the raw rlp encoded blob of the cached trie node, either directly
+// from the cache, or by regenerating it from the collapsed node.
+// nolint:unused
+func (n *memoryNode) rlp() []byte {
+ if node, ok := n.node.(rawNode); ok {
+ return node
+ }
+ return nodeToBytes(n.node)
+}
+
+// obj returns the decoded and expanded trie node, either directly from the cache,
+// or by regenerating it from the rlp encoded blob.
+// nolint:unused
+func (n *memoryNode) obj() node {
+ if node, ok := n.node.(rawNode); ok {
+ return mustDecodeNode(n.hash[:], node)
+ }
+ return expandNode(n.hash[:], n.node)
+}
+
+// isDeleted returns the indicator if the node is marked as deleted.
+func (n *memoryNode) isDeleted() bool {
+ return n.hash == (common.Hash{})
+}
+
+// nodeWithPrev wraps the memoryNode with the previous node value.
+// nolint: unused
+type nodeWithPrev struct {
+ *memoryNode
+ prev []byte // RLP-encoded previous value, nil means it's non-existent
+}
+
+// unwrap returns the internal memoryNode object.
+// nolint:unused
+func (n *nodeWithPrev) unwrap() *memoryNode {
+ return n.memoryNode
+}
+
+// memorySize returns the total memory size used by this node. It overloads
+// the function in memoryNode by counting the size of previous value as well.
+// nolint: unused
+func (n *nodeWithPrev) memorySize(pathlen int) int {
+ return n.memoryNode.memorySize(pathlen) + len(n.prev)
+}
+
+// NodeSet contains all dirty nodes collected during the commit operation.
+// Each node is keyed by path. It's not thread-safe to use.
+type NodeSet struct {
+ owner common.Hash // the identifier of the trie
+ nodes map[string]*memoryNode // the set of dirty nodes (inserted, updated, deleted)
+ leaves []*leaf // the list of dirty leaves
+ updates int // the count of updated and inserted nodes
+ deletes int // the count of deleted nodes
+
+ // The list of accessed nodes, which records the original node value.
+ // The origin value is expected to be nil for a newly inserted node
+ // and non-nil for other types (updated, deleted).
+ accessList map[string][]byte
+}
+
+// NewNodeSet initializes an empty node set to be used for tracking dirty nodes
+// from a specific account or storage trie. The owner is zero for the account
+// trie and the owning account address hash for storage tries.
+func NewNodeSet(owner common.Hash, accessList map[string][]byte) *NodeSet {
+ return &NodeSet{
+ owner: owner,
+ nodes: make(map[string]*memoryNode),
+ accessList: accessList,
+ }
+}
+
+// forEachWithOrder iterates the dirty nodes with the order from bottom to top,
+// right to left, nodes with the longest path will be iterated first.
+func (set *NodeSet) forEachWithOrder(callback func(path string, n *memoryNode)) {
+ var paths sort.StringSlice
+ for path := range set.nodes {
+ paths = append(paths, path)
+ }
+ // Bottom-up, longest path first
+ sort.Sort(sort.Reverse(paths))
+ for _, path := range paths {
+ callback(path, set.nodes[path])
+ }
+}
+
+// markUpdated marks the node as dirty(newly-inserted or updated).
+func (set *NodeSet) markUpdated(path []byte, node *memoryNode) {
+ set.nodes[string(path)] = node
+ set.updates += 1
+}
+
+// markDeleted marks the node as deleted.
+func (set *NodeSet) markDeleted(path []byte) {
+ set.nodes[string(path)] = &memoryNode{}
+ set.deletes += 1
+}
+
+// addLeaf collects the provided leaf node into set.
+func (set *NodeSet) addLeaf(node *leaf) {
+ set.leaves = append(set.leaves, node)
+}
+
+// Size returns the number of dirty nodes in set.
+func (set *NodeSet) Size() (int, int) {
+ return set.updates, set.deletes
+}
+
+// Hashes returns the hashes of all updated nodes. TODO(rjl493456442) how can
+// we get rid of it?
+func (set *NodeSet) Hashes() []common.Hash {
+ var ret []common.Hash
+ for _, node := range set.nodes {
+ ret = append(ret, node.hash)
+ }
+ return ret
+}
+
+// Summary returns a string-representation of the NodeSet.
+func (set *NodeSet) Summary() string {
+ var out = new(strings.Builder)
+ fmt.Fprintf(out, "nodeset owner: %v\n", set.owner)
+ if set.nodes != nil {
+ for path, n := range set.nodes {
+ // Deletion
+ if n.isDeleted() {
+ fmt.Fprintf(out, " [-]: %x prev: %x\n", path, set.accessList[path])
+ continue
+ }
+ // Insertion
+ origin, ok := set.accessList[path]
+ if !ok {
+ fmt.Fprintf(out, " [+]: %x -> %v\n", path, n.hash)
+ continue
+ }
+ // Update
+ fmt.Fprintf(out, " [*]: %x -> %v prev: %x\n", path, n.hash, origin)
+ }
+ }
+ for _, n := range set.leaves {
+ fmt.Fprintf(out, "[leaf]: %v\n", n)
+ }
+ return out.String()
+}
+
+// MergedNodeSet represents a merged dirty node set for a group of tries.
+type MergedNodeSet struct {
+ sets map[common.Hash]*NodeSet
+}
+
+// NewMergedNodeSet initializes an empty merged set.
+func NewMergedNodeSet() *MergedNodeSet {
+ return &MergedNodeSet{sets: make(map[common.Hash]*NodeSet)}
+}
+
+// NewWithNodeSet constructs a merged nodeset with the provided single set.
+func NewWithNodeSet(set *NodeSet) *MergedNodeSet {
+ merged := NewMergedNodeSet()
+ merged.Merge(set)
+ return merged
+}
+
+// Merge merges the provided dirty nodes of a trie into the set. It assumes
+// that no two sets belonging to the same trie are ever merged.
+func (set *MergedNodeSet) Merge(other *NodeSet) error {
+ _, present := set.sets[other.owner]
+ if present {
+ return fmt.Errorf("duplicate trie for owner %#x", other.owner)
+ }
+ set.sets[other.owner] = other
+ return nil
+}
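The tests elsewhere in this diff rely on the commit flow these types enable; a minimal sketch (hypothetical helper, with the error check the tests omit):

```go
// commitAndFlush commits a trie and applies its dirty nodes to the database.
func commitAndFlush(db *Database, tr *Trie) (common.Hash, error) {
	root, nodes := tr.Commit(false) // nodes is nil when the trie is clean
	if nodes != nil {
		if err := db.Update(NewWithNodeSet(nodes)); err != nil {
			return common.Hash{}, err
		}
	}
	return root, nil
}
```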
diff --git a/trie_by_cid/trie/preimages.go b/trie_by_cid/trie/preimages.go
new file mode 100644
index 0000000..a6359ca
--- /dev/null
+++ b/trie_by_cid/trie/preimages.go
@@ -0,0 +1,78 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package trie
+
+import (
+ "sync"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/ethdb"
+)
+
+// preimageStore is the store for caching preimages of node keys.
+type preimageStore struct {
+ lock sync.RWMutex
+ disk ethdb.KeyValueStore
+ preimages map[common.Hash][]byte // Preimages of nodes from the secure trie
+ preimagesSize common.StorageSize // Storage size of the preimages cache
+}
+
+// newPreimageStore initializes the store for caching preimages.
+func newPreimageStore(disk ethdb.KeyValueStore) *preimageStore {
+ return &preimageStore{
+ disk: disk,
+ preimages: make(map[common.Hash][]byte),
+ }
+}
+
+// insertPreimage writes a new trie node pre-image to the memory database if it
+// is not yet known. The method does NOT make a copy of the slice; only use it
+// if the preimage will NOT be changed later on.
+func (store *preimageStore) insertPreimage(preimages map[common.Hash][]byte) {
+ store.lock.Lock()
+ defer store.lock.Unlock()
+
+ for hash, preimage := range preimages {
+ if _, ok := store.preimages[hash]; ok {
+ continue
+ }
+ store.preimages[hash] = preimage
+ store.preimagesSize += common.StorageSize(common.HashLength + len(preimage))
+ }
+}
+
+// preimage retrieves a cached trie node pre-image from memory. If it cannot be
+// found cached, the method queries the persistent database for the content.
+func (store *preimageStore) preimage(hash common.Hash) []byte {
+ store.lock.RLock()
+ preimage := store.preimages[hash]
+ store.lock.RUnlock()
+
+ if preimage != nil {
+ return preimage
+ }
+ return rawdb.ReadPreimage(store.disk, hash)
+}
+
+// size returns the current storage size of accumulated preimages.
+func (store *preimageStore) size() common.StorageSize {
+ store.lock.RLock()
+ defer store.lock.RUnlock()
+
+ return store.preimagesSize
+}
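A sketch of the intended usage (hypothetical helper; crypto is go-ethereum's crypto package):

```go
// recordPreimage caches the preimage of a hashed key so that a later
// store.preimage(hashed) returns key, falling back to rawdb on a cache miss.
func recordPreimage(store *preimageStore, key []byte) common.Hash {
	hashed := common.BytesToHash(crypto.Keccak256(key))
	store.insertPreimage(map[common.Hash][]byte{hashed: key})
	return hashed
}
```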
diff --git a/trie_by_cid/trie/proof.go b/trie_by_cid/trie/proof.go
index e48eda5..7315c0d 100644
--- a/trie_by_cid/trie/proof.go
+++ b/trie_by_cid/trie/proof.go
@@ -20,9 +20,10 @@ import (
"bytes"
"fmt"
+ "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/ethdb"
- "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/trie"
+ log "github.com/sirupsen/logrus"
)
var VerifyProof = trie.VerifyProof
@@ -61,10 +62,15 @@ func (t *Trie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) e
key = key[1:]
nodes = append(nodes, n)
case hashNode:
+ // Retrieve the specified node from the underlying node reader.
+ // trie.resolveAndTrack is not used because it tracks every loaded
+ // blob; that is unnecessary here, since nodes loaded for a proof are
+ // never linked back into the trie, and tracking them could lead to
+ // out-of-memory issues.
var err error
- tn, err = t.resolveHash(n, prefix)
+ tn, err = t.reader.node(prefix, common.BytesToHash(n))
if err != nil {
- log.Error(fmt.Sprintf("Unhandled trie error: %v", err))
+ log.Errorf("Unhandled trie error in Trie.Prove: %v", err)
return err
}
default:
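Prove composes with the re-exported VerifyProof exactly as the provers in proof_test.go use them; a minimal sketch (hypothetical helper, assuming the memorydb import the tests already use):

```go
// proveAndVerify builds a Merkle proof for key and checks it against the root.
func proveAndVerify(tr *Trie, key []byte) ([]byte, error) {
	proof := memorydb.New()
	if err := tr.Prove(key, 0, proof); err != nil {
		return nil, err
	}
	return VerifyProof(tr.Hash(), key, proof)
}
```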
diff --git a/trie_by_cid/trie/proof_test.go b/trie_by_cid/trie/proof_test.go
index 3fd508a..6b23bcd 100644
--- a/trie_by_cid/trie/proof_test.go
+++ b/trie_by_cid/trie/proof_test.go
@@ -14,29 +14,39 @@
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-package trie_test
+package trie
import (
"bytes"
+ crand "crypto/rand"
+ "encoding/binary"
+ "fmt"
mrand "math/rand"
"sort"
"testing"
- "time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb/memorydb"
- geth_trie "github.com/ethereum/go-ethereum/trie"
-
- . "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie"
)
-// tweakable trie size
-var scaleFactor = 512
+// prng is a pseudo-random number generator seeded by strong randomness.
+// The randomness is printed on startup in order to make failures reproducible.
+var prng = initRnd()
-func init() {
- mrand.Seed(time.Now().UnixNano())
+func initRnd() *mrand.Rand {
+ var seed [8]byte
+ crand.Read(seed[:])
+ rnd := mrand.New(mrand.NewSource(int64(binary.LittleEndian.Uint64(seed[:]))))
+ fmt.Printf("Seed: %x\n", seed)
+ return rnd
+}
+
+func randBytes(n int) []byte {
+ r := make([]byte, n)
+ prng.Read(r)
+ return r
}
// makeProvers creates Merkle trie provers based on different implementations to
@@ -64,7 +74,7 @@ func makeProvers(trie *Trie) []func(key []byte) *memorydb.Database {
}
func TestProof(t *testing.T) {
- trie, vals := randomTrie(t, scaleFactor)
+ trie, vals := randomTrie(500)
root := trie.Hash()
for i, prover := range makeProvers(trie) {
for _, kv := range vals {
@@ -72,11 +82,11 @@ func TestProof(t *testing.T) {
if proof == nil {
t.Fatalf("prover %d: missing key %x while constructing proof", i, kv.k)
}
- val, err := geth_trie.VerifyProof(root, kv.k, proof)
+ val, err := VerifyProof(root, kv.k, proof)
if err != nil {
t.Fatalf("prover %d: failed to verify proof for key %x: %v\nraw proof: %x", i, kv.k, err, proof)
}
- if kv.v != unpackValue(val) {
+ if !bytes.Equal(val, kv.v) {
t.Fatalf("prover %d: verified value mismatch for key %x: have %x, want %x", i, kv.k, val, kv.v)
}
}
@@ -84,13 +94,8 @@ func TestProof(t *testing.T) {
}
func TestOneElementProof(t *testing.T) {
- edb := rawdb.NewMemoryDatabase()
- db := geth_trie.NewDatabase(edb)
- orig := geth_trie.NewEmpty(db)
- orig.Update([]byte("k"), packValue(42))
- root := commitTrie(t, db, orig)
- trie := indexTrie(t, edb, root)
-
+ trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ updateString(trie, "k", "v")
for i, prover := range makeProvers(trie) {
proof := prover([]byte("k"))
if proof == nil {
@@ -99,18 +104,18 @@ func TestOneElementProof(t *testing.T) {
if proof.Len() != 1 {
t.Errorf("prover %d: proof should have one element", i)
}
- val, err := geth_trie.VerifyProof(trie.Hash(), []byte("k"), proof)
+ val, err := VerifyProof(trie.Hash(), []byte("k"), proof)
if err != nil {
t.Fatalf("prover %d: failed to verify proof: %v\nraw proof: %x", i, err, proof)
}
- if 42 != unpackValue(val) {
+ if !bytes.Equal(val, []byte("v")) {
t.Fatalf("prover %d: verified value mismatch: have %x, want 'k'", i, val)
}
}
}
func TestBadProof(t *testing.T) {
- trie, vals := randomTrie(t, 2*scaleFactor)
+ trie, vals := randomTrie(800)
root := trie.Hash()
for i, prover := range makeProvers(trie) {
for _, kv := range vals {
@@ -140,12 +145,8 @@ func TestBadProof(t *testing.T) {
// Tests that missing keys can also be proven. The test explicitly uses a single
// entry trie and checks for missing keys both before and after the single entry.
func TestMissingKeyProof(t *testing.T) {
- edb := rawdb.NewMemoryDatabase()
- db := geth_trie.NewDatabase(edb)
- orig := geth_trie.NewEmpty(db)
- orig.Update([]byte("k"), packValue(42))
- root := commitTrie(t, db, orig)
- trie := indexTrie(t, edb, root)
+ trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ updateString(trie, "k", "v")
for i, key := range []string{"a", "j", "l", "z"} {
proof := memorydb.New()
@@ -164,15 +165,7 @@ func TestMissingKeyProof(t *testing.T) {
}
}
-type entry struct {
- k, v []byte
-}
-
-func packEntry(kv *kv) *entry {
- return &entry{kv.k, packValue(kv.v)}
-}
-
-type entrySlice []*entry
+type entrySlice []*kv
func (p entrySlice) Len() int { return len(p) }
func (p entrySlice) Less(i, j int) bool { return bytes.Compare(p[i].k, p[j].k) < 0 }
@@ -181,10 +174,10 @@ func (p entrySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
// TestRangeProof tests normal range proof with both edge proofs
// as the existent proof. The test cases are generated randomly.
func TestRangeProof(t *testing.T) {
- trie, vals := randomTrie(t, 8*scaleFactor)
+ trie, vals := randomTrie(4096)
var entries entrySlice
for _, kv := range vals {
- entries = append(entries, packEntry(kv))
+ entries = append(entries, kv)
}
sort.Sort(entries)
for i := 0; i < 500; i++ {
@@ -214,10 +207,10 @@ func TestRangeProof(t *testing.T) {
// TestRangeProof tests normal range proof with two non-existent proofs.
// The test cases are generated randomly.
func TestRangeProofWithNonExistentProof(t *testing.T) {
- trie, vals := randomTrie(t, 8*scaleFactor)
+ trie, vals := randomTrie(4096)
var entries entrySlice
for _, kv := range vals {
- entries = append(entries, packEntry(kv))
+ entries = append(entries, kv)
}
sort.Sort(entries)
for i := 0; i < 500; i++ {
@@ -286,10 +279,10 @@ func TestRangeProofWithNonExistentProof(t *testing.T) {
// - There exists a gap between the first element and the left edge proof
// - There exists a gap between the last element and the right edge proof
func TestRangeProofWithInvalidNonExistentProof(t *testing.T) {
- trie, vals := randomTrie(t, 8*scaleFactor)
+ trie, vals := randomTrie(4096)
var entries entrySlice
for _, kv := range vals {
- entries = append(entries, packEntry(kv))
+ entries = append(entries, kv)
}
sort.Sort(entries)
@@ -343,10 +336,10 @@ func TestRangeProofWithInvalidNonExistentProof(t *testing.T) {
// element. The first edge proof can be existent one or
// non-existent one.
func TestOneElementRangeProof(t *testing.T) {
- trie, vals := randomTrie(t, 8*scaleFactor)
+ trie, vals := randomTrie(4096)
var entries entrySlice
for _, kv := range vals {
- entries = append(entries, packEntry(kv))
+ entries = append(entries, kv)
}
sort.Sort(entries)
@@ -408,38 +401,32 @@ func TestOneElementRangeProof(t *testing.T) {
}
// Test the mini trie with only a single element.
- t.Run("single element", func(t *testing.T) {
- edb := rawdb.NewMemoryDatabase()
- db := geth_trie.NewDatabase(edb)
- orig := geth_trie.NewEmpty(db)
- entry := &entry{randBytes(32), packValue(mrand.Int63())}
- orig.Update(entry.k, entry.v)
- root := commitTrie(t, db, orig)
- tinyTrie := indexTrie(t, edb, root)
+ tinyTrie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ entry := &kv{randBytes(32), randBytes(20), false}
+ tinyTrie.Update(entry.k, entry.v)
- first = common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes()
- last = entry.k
- proof = memorydb.New()
- if err := tinyTrie.Prove(first, 0, proof); err != nil {
- t.Fatalf("Failed to prove the first node %v", err)
- }
- if err := tinyTrie.Prove(last, 0, proof); err != nil {
- t.Fatalf("Failed to prove the last node %v", err)
- }
- _, err = VerifyRangeProof(tinyTrie.Hash(), first, last, [][]byte{entry.k}, [][]byte{entry.v}, proof)
- if err != nil {
- t.Fatalf("Expected no error, got %v", err)
- }
- })
+ first = common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes()
+ last = entry.k
+ proof = memorydb.New()
+ if err := tinyTrie.Prove(first, 0, proof); err != nil {
+ t.Fatalf("Failed to prove the first node %v", err)
+ }
+ if err := tinyTrie.Prove(last, 0, proof); err != nil {
+ t.Fatalf("Failed to prove the last node %v", err)
+ }
+ _, err = VerifyRangeProof(tinyTrie.Hash(), first, last, [][]byte{entry.k}, [][]byte{entry.v}, proof)
+ if err != nil {
+ t.Fatalf("Expected no error, got %v", err)
+ }
}
// TestAllElementsProof tests the range proof with all elements.
// The edge proofs can be nil.
func TestAllElementsProof(t *testing.T) {
- trie, vals := randomTrie(t, 8*scaleFactor)
+ trie, vals := randomTrie(4096)
var entries entrySlice
for _, kv := range vals {
- entries = append(entries, packEntry(kv))
+ entries = append(entries, kv)
}
sort.Sort(entries)
@@ -483,86 +470,73 @@ func TestAllElementsProof(t *testing.T) {
}
}
-func positionCases(size int) []int {
- var cases []int
- for _, pos := range []int{0, 1, 50, 100, 1000, 2000, size - 1} {
- if pos >= size {
- break
- }
- cases = append(cases, pos)
- }
- return cases
-}
-
// TestSingleSideRangeProof tests the range starts from zero.
func TestSingleSideRangeProof(t *testing.T) {
- edb := rawdb.NewMemoryDatabase()
- db := geth_trie.NewDatabase(edb)
- orig := geth_trie.NewEmpty(db)
- var entries entrySlice
- for i := 0; i < 8*scaleFactor; i++ {
- value := &entry{randBytes(32), packValue(mrand.Int63())}
- orig.Update(value.k, value.v)
- entries = append(entries, value)
- }
- root := commitTrie(t, db, orig)
- trie := indexTrie(t, edb, root)
- sort.Sort(entries)
+ for i := 0; i < 64; i++ {
+ trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ var entries entrySlice
+ for i := 0; i < 4096; i++ {
+ value := &kv{randBytes(32), randBytes(20), false}
+ trie.Update(value.k, value.v)
+ entries = append(entries, value)
+ }
+ sort.Sort(entries)
- for _, pos := range positionCases(len(entries)) {
- proof := memorydb.New()
- if err := trie.Prove(common.Hash{}.Bytes(), 0, proof); err != nil {
- t.Fatalf("Failed to prove the first node %v", err)
- }
- if err := trie.Prove(entries[pos].k, 0, proof); err != nil {
- t.Fatalf("Failed to prove the first node %v", err)
- }
- k := make([][]byte, 0)
- v := make([][]byte, 0)
- for i := 0; i <= pos; i++ {
- k = append(k, entries[i].k)
- v = append(v, entries[i].v)
- }
- _, err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k[len(k)-1], k, v, proof)
- if err != nil {
- t.Fatalf("Expected no error, got %v", err)
+ var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1}
+ for _, pos := range cases {
+ proof := memorydb.New()
+ if err := trie.Prove(common.Hash{}.Bytes(), 0, proof); err != nil {
+ t.Fatalf("Failed to prove the first node %v", err)
+ }
+ if err := trie.Prove(entries[pos].k, 0, proof); err != nil {
+ t.Fatalf("Failed to prove the first node %v", err)
+ }
+ k := make([][]byte, 0)
+ v := make([][]byte, 0)
+ for i := 0; i <= pos; i++ {
+ k = append(k, entries[i].k)
+ v = append(v, entries[i].v)
+ }
+ _, err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k[len(k)-1], k, v, proof)
+ if err != nil {
+ t.Fatalf("Expected no error, got %v", err)
+ }
}
}
}
// TestReverseSingleSideRangeProof tests the range ends with 0xffff...fff.
func TestReverseSingleSideRangeProof(t *testing.T) {
- edb := rawdb.NewMemoryDatabase()
- db := geth_trie.NewDatabase(edb)
- orig := geth_trie.NewEmpty(db)
- var entries entrySlice
- for i := 0; i < 8*scaleFactor; i++ {
- value := &entry{randBytes(32), packValue(mrand.Int63())}
- orig.Update(value.k, value.v)
- entries = append(entries, value)
- }
- root := commitTrie(t, db, orig)
- trie := indexTrie(t, edb, root)
- sort.Sort(entries)
+ for i := 0; i < 64; i++ {
+ trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ var entries entrySlice
+ for i := 0; i < 4096; i++ {
+ value := &kv{randBytes(32), randBytes(20), false}
+ trie.Update(value.k, value.v)
+ entries = append(entries, value)
+ }
+ sort.Sort(entries)
- for _, pos := range positionCases(len(entries)) {
- proof := memorydb.New()
- if err := trie.Prove(entries[pos].k, 0, proof); err != nil {
- t.Fatalf("Failed to prove the first node %v", err)
- }
- last := common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
- if err := trie.Prove(last.Bytes(), 0, proof); err != nil {
- t.Fatalf("Failed to prove the last node %v", err)
- }
- k := make([][]byte, 0)
- v := make([][]byte, 0)
- for i := pos; i < len(entries); i++ {
- k = append(k, entries[i].k)
- v = append(v, entries[i].v)
- }
- _, err := VerifyRangeProof(trie.Hash(), k[0], last.Bytes(), k, v, proof)
- if err != nil {
- t.Fatalf("Expected no error, got %v", err)
+ var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1}
+ for _, pos := range cases {
+ proof := memorydb.New()
+ if err := trie.Prove(entries[pos].k, 0, proof); err != nil {
+ t.Fatalf("Failed to prove the first node %v", err)
+ }
+ last := common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
+ if err := trie.Prove(last.Bytes(), 0, proof); err != nil {
+ t.Fatalf("Failed to prove the last node %v", err)
+ }
+ k := make([][]byte, 0)
+ v := make([][]byte, 0)
+ for i := pos; i < len(entries); i++ {
+ k = append(k, entries[i].k)
+ v = append(v, entries[i].v)
+ }
+ _, err := VerifyRangeProof(trie.Hash(), k[0], last.Bytes(), k, v, proof)
+ if err != nil {
+ t.Fatalf("Expected no error, got %v", err)
+ }
}
}
}
@@ -570,10 +544,10 @@ func TestReverseSingleSideRangeProof(t *testing.T) {
// TestBadRangeProof tests a few cases which the proof is wrong.
// The prover is expected to detect the error.
func TestBadRangeProof(t *testing.T) {
- trie, vals := randomTrie(t, 8*scaleFactor)
+ trie, vals := randomTrie(4096)
var entries entrySlice
for _, kv := range vals {
- entries = append(entries, packEntry(kv))
+ entries = append(entries, kv)
}
sort.Sort(entries)
@@ -641,17 +615,13 @@ func TestBadRangeProof(t *testing.T) {
// TestGappedRangeProof focuses on the small trie with embedded nodes.
// If the gapped node is embedded in the trie, it should be detected too.
func TestGappedRangeProof(t *testing.T) {
- edb := rawdb.NewMemoryDatabase()
- db := geth_trie.NewDatabase(edb)
- orig := geth_trie.NewEmpty(db)
- var entries entrySlice
+ trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ var entries []*kv // Sorted entries
for i := byte(0); i < 10; i++ {
- value := &entry{common.LeftPadBytes([]byte{i}, 32), packValue(int64(i))}
- orig.Update(value.k, value.v)
+ value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
+ trie.Update(value.k, value.v)
entries = append(entries, value)
}
- root := commitTrie(t, db, orig)
- trie := indexTrie(t, edb, root)
first, last := 2, 8
proof := memorydb.New()
if err := trie.Prove(entries[first].k, 0, proof); err != nil {
@@ -677,10 +647,10 @@ func TestGappedRangeProof(t *testing.T) {
// TestSameSideProofs tests the case where the element is not within the range covered by the proofs
func TestSameSideProofs(t *testing.T) {
- trie, vals := randomTrie(t, 8*scaleFactor)
+ trie, vals := randomTrie(4096)
var entries entrySlice
for _, kv := range vals {
- entries = append(entries, packEntry(kv))
+ entries = append(entries, kv)
}
sort.Sort(entries)
@@ -719,17 +689,13 @@ func TestSameSideProofs(t *testing.T) {
}
func TestHasRightElement(t *testing.T) {
- edb := rawdb.NewMemoryDatabase()
- db := geth_trie.NewDatabase(edb)
- orig := geth_trie.NewEmpty(db)
+ trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
var entries entrySlice
- for i := 0; i < 8*scaleFactor; i++ {
- value := &entry{randBytes(32), packValue(int64(i))}
- orig.Update(value.k, value.v)
+ for i := 0; i < 4096; i++ {
+ value := &kv{randBytes(32), randBytes(20), false}
+ trie.Update(value.k, value.v)
entries = append(entries, value)
}
- root := commitTrie(t, db, orig)
- trie := indexTrie(t, edb, root)
sort.Sort(entries)
var cases = []struct {
@@ -797,10 +763,10 @@ func TestHasRightElement(t *testing.T) {
// TestEmptyRangeProof tests the range proof with "no" element.
// The first edge proof must be a non-existent proof.
func TestEmptyRangeProof(t *testing.T) {
- trie, vals := randomTrie(t, 8*scaleFactor)
+ trie, vals := randomTrie(4096)
var entries entrySlice
for _, kv := range vals {
- entries = append(entries, packEntry(kv))
+ entries = append(entries, kv)
}
sort.Sort(entries)
@@ -827,14 +793,49 @@ func TestEmptyRangeProof(t *testing.T) {
}
}
+// TestBloatedProof tests a malicious proof, where the proof is more or less the
+// whole trie. Previously we didn't accept such packets, but the new APIs do, so
+// let's keep this admittedly odd test in place.
+func TestBloatedProof(t *testing.T) {
+ // Use a small trie
+ trie, kvs := nonRandomTrie(100)
+ var entries entrySlice
+ for _, kv := range kvs {
+ entries = append(entries, kv)
+ }
+ sort.Sort(entries)
+ var keys [][]byte
+ var vals [][]byte
+
+ proof := memorydb.New()
+ // In the 'malicious' case, we add proofs for every single item
+ // (but only one key/value pair used as leaf)
+ for i, entry := range entries {
+ trie.Prove(entry.k, 0, proof)
+ if i == 50 {
+ keys = append(keys, entry.k)
+ vals = append(vals, entry.v)
+ }
+ }
+ // For reference, we use the same function, but _only_ prove the first
+ // and last element
+ want := memorydb.New()
+ trie.Prove(keys[0], 0, want)
+ trie.Prove(keys[len(keys)-1], 0, want)
+
+ if _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof); err != nil {
+ t.Fatalf("expected bloated proof to succeed, got %v", err)
+ }
+}
+
// TestEmptyValueRangeProof tests a normal range proof with both edge proofs
// as existence proofs, but with an extra empty value included, which is
// technically a no-op but should in practice be rejected.
func TestEmptyValueRangeProof(t *testing.T) {
- trie, values := randomTrie(t, scaleFactor)
+ trie, values := randomTrie(512)
var entries entrySlice
for _, kv := range values {
- entries = append(entries, packEntry(kv))
+ entries = append(entries, kv)
}
sort.Sort(entries)
@@ -847,8 +848,8 @@ func TestEmptyValueRangeProof(t *testing.T) {
break
}
}
- noop := &entry{key, []byte{}}
- entries = append(append(append(entrySlice{}, entries[:mid]...), noop), entries[mid:]...)
+ noop := &kv{key, []byte{}, false}
+ entries = append(append(append([]*kv{}, entries[:mid]...), noop), entries[mid:]...)
start, end := 1, len(entries)-1
@@ -875,10 +876,10 @@ func TestEmptyValueRangeProof(t *testing.T) {
// but with an extra empty value included, which is technically a no-op but
// should in practice be rejected.
func TestAllElementsEmptyValueRangeProof(t *testing.T) {
- trie, values := randomTrie(t, scaleFactor)
+ trie, values := randomTrie(512)
var entries entrySlice
for _, kv := range values {
- entries = append(entries, packEntry(kv))
+ entries = append(entries, kv)
}
sort.Sort(entries)
@@ -891,8 +892,8 @@ func TestAllElementsEmptyValueRangeProof(t *testing.T) {
break
}
}
- noop := &entry{key, nil}
- entries = append(append(append(entrySlice{}, entries[:mid]...), noop), entries[mid:]...)
+ noop := &kv{key, []byte{}, false}
+ entries = append(append(append([]*kv{}, entries[:mid]...), noop), entries[mid:]...)
var keys [][]byte
var vals [][]byte
@@ -938,7 +939,7 @@ func decreaseKey(key []byte) []byte {
}
func BenchmarkProve(b *testing.B) {
- trie, vals := randomTrie(b, 100)
+ trie, vals := randomTrie(100)
var keys []string
for k := range vals {
keys = append(keys, k)
@@ -955,7 +956,7 @@ func BenchmarkProve(b *testing.B) {
}
func BenchmarkVerifyProof(b *testing.B) {
- trie, vals := randomTrie(b, 100)
+ trie, vals := randomTrie(100)
root := trie.Hash()
var keys []string
var proofs []*memorydb.Database
@@ -981,10 +982,10 @@ func BenchmarkVerifyRangeProof1000(b *testing.B) { benchmarkVerifyRangeProof(b,
func BenchmarkVerifyRangeProof5000(b *testing.B) { benchmarkVerifyRangeProof(b, 5000) }
func benchmarkVerifyRangeProof(b *testing.B, size int) {
- trie, vals := randomTrie(b, 8192)
+ trie, vals := randomTrie(8192)
var entries entrySlice
for _, kv := range vals {
- entries = append(entries, packEntry(kv))
+ entries = append(entries, kv)
}
sort.Sort(entries)
@@ -1018,10 +1019,10 @@ func BenchmarkVerifyRangeNoProof500(b *testing.B) { benchmarkVerifyRangeNoProof
func BenchmarkVerifyRangeNoProof1000(b *testing.B) { benchmarkVerifyRangeNoProof(b, 1000) }
func benchmarkVerifyRangeNoProof(b *testing.B, size int) {
- trie, vals := randomTrie(b, size)
+ trie, vals := randomTrie(size)
var entries entrySlice
for _, kv := range vals {
- entries = append(entries, packEntry(kv))
+ entries = append(entries, kv)
}
sort.Sort(entries)
@@ -1040,24 +1041,56 @@ func benchmarkVerifyRangeNoProof(b *testing.B, size int) {
}
}
+func randomTrie(n int) (*Trie, map[string]*kv) {
+ trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ vals := make(map[string]*kv)
+ for i := byte(0); i < 100; i++ {
+ value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
+ value2 := &kv{common.LeftPadBytes([]byte{i + 10}, 32), []byte{i}, false}
+ trie.Update(value.k, value.v)
+ trie.Update(value2.k, value2.v)
+ vals[string(value.k)] = value
+ vals[string(value2.k)] = value2
+ }
+ for i := 0; i < n; i++ {
+ value := &kv{randBytes(32), randBytes(20), false}
+ trie.Update(value.k, value.v)
+ vals[string(value.k)] = value
+ }
+ return trie, vals
+}
+
+func nonRandomTrie(n int) (*Trie, map[string]*kv) {
+ trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ vals := make(map[string]*kv)
+ max := uint64(0xffffffffffffffff)
+ for i := uint64(0); i < uint64(n); i++ {
+ value := make([]byte, 32)
+ key := make([]byte, 32)
+ binary.LittleEndian.PutUint64(key, i)
+ // Note: i-max wraps around, so the stored value is i+1 (mod 2^64).
+ binary.LittleEndian.PutUint64(value, i-max)
+ elem := &kv{key, value, false}
+ trie.Update(elem.k, elem.v)
+ vals[string(elem.k)] = elem
+ }
+ return trie, vals
+}
+
func TestRangeProofKeysWithSharedPrefix(t *testing.T) {
keys := [][]byte{
common.Hex2Bytes("aa10000000000000000000000000000000000000000000000000000000000000"),
common.Hex2Bytes("aa20000000000000000000000000000000000000000000000000000000000000"),
}
vals := [][]byte{
- packValue(2),
- packValue(3),
+ common.Hex2Bytes("02"),
+ common.Hex2Bytes("03"),
}
- edb := rawdb.NewMemoryDatabase()
- db := geth_trie.NewDatabase(edb)
- orig := geth_trie.NewEmpty(db)
+ trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
for i, key := range keys {
- orig.Update(key, vals[i])
+ trie.Update(key, vals[i])
}
- root := commitTrie(t, db, orig)
- trie := indexTrie(t, edb, root)
-
+ root := trie.Hash()
proof := memorydb.New()
start := common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000")
end := common.Hex2Bytes("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
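The range-proof tests above all share the same prove-then-verify shape. Below is a minimal sketch of that round trip, using only names that already appear in this diff (NewEmpty, NewDatabase, Trie.Prove, VerifyRangeProof, entrySlice, kv); the helper function itself is hypothetical:

// exampleRangeProof sketches the round trip exercised by the range-proof
// tests: build a trie, prove the two edge keys into one proof db, then
// verify the contiguous key/value range against the root hash.
func exampleRangeProof() error {
	trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
	var entries entrySlice
	for i := 0; i < 100; i++ {
		e := &kv{randBytes(32), randBytes(20), false}
		trie.Update(e.k, e.v)
		entries = append(entries, e)
	}
	sort.Sort(entries)

	first, last := 10, 20
	proof := memorydb.New()
	if err := trie.Prove(entries[first].k, 0, proof); err != nil {
		return err
	}
	if err := trie.Prove(entries[last].k, 0, proof); err != nil {
		return err
	}
	var keys, vals [][]byte
	for i := first; i <= last; i++ {
		keys = append(keys, entries[i].k)
		vals = append(vals, entries[i].v)
	}
	_, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof)
	return err
}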
diff --git a/trie_by_cid/trie/secure_trie.go b/trie_by_cid/trie/secure_trie.go
index 6e4b4ca..25f1fd7 100644
--- a/trie_by_cid/trie/secure_trie.go
+++ b/trie_by_cid/trie/secure_trie.go
@@ -17,82 +17,88 @@
package trie
import (
- "fmt"
-
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/rlp"
- "github.com/ethereum/go-ethereum/statediff/indexer/ipld"
+ log "github.com/sirupsen/logrus"
)
-// StateTrie wraps a trie with key hashing. In a secure trie, all
+// StateTrie wraps a trie with key hashing. In a StateTrie, all
// access operations hash the key using keccak256. This prevents
// calling code from creating long chains of nodes that
// increase the access time.
//
// Contrary to a regular trie, a StateTrie can only be created with
-// New and must have an attached database.
+// New and must have an attached database. The database also stores
+// the preimage of each key if preimage recording is enabled.
//
// StateTrie is not safe for concurrent use.
type StateTrie struct {
- trie Trie
- hashKeyBuf [common.HashLength]byte
+ trie Trie
+ preimages *preimageStore
+ hashKeyBuf [common.HashLength]byte
+ secKeyCache map[string][]byte
+ secKeyCacheOwner *StateTrie // Pointer to self, replace the key cache on mismatch
}
-// NewStateTrie creates a trie with an existing root node from a backing database
-// and optional intermediate in-memory node pool.
+// NewStateTrie creates a trie with an existing root node from a backing database.
//
// If root is the zero hash or the sha3 hash of an empty string, the
// trie is initially empty. Otherwise, New will panic if db is nil
// and returns MissingNodeError if the root node cannot be found.
-//
-// Accessing the trie loads nodes from the database or node pool on demand.
-// Loaded nodes are kept around until their 'cache generation' expires.
-// A new cache generation is created by each call to Commit.
-// cachelimit sets the number of past cache generations to keep.
-//
-// Retrieves IPLD blocks by CID encoded as "eth-state-trie"
-func NewStateTrie(owner common.Hash, root common.Hash, db *Database) (*StateTrie, error) {
- return newStateTrie(owner, root, db, ipld.MEthStateTrie)
-}
-
-// NewStorageTrie is identical to NewStateTrie, but retrieves IPLD blocks encoded
-// as "eth-storage-trie"
-func NewStorageTrie(owner common.Hash, root common.Hash, db *Database) (*StateTrie, error) {
- return newStateTrie(owner, root, db, ipld.MEthStorageTrie)
-}
-
-func newStateTrie(owner common.Hash, root common.Hash, db *Database, codec uint64) (*StateTrie, error) {
+func NewStateTrie(id *ID, db *Database, codec uint64) (*StateTrie, error) {
if db == nil {
- panic("NewStateTrie called without a database")
+ panic("trie.NewStateTrie called without a database")
}
- trie, err := New(owner, root, db, codec)
+ trie, err := New(id, db, codec)
if err != nil {
return nil, err
}
- return &StateTrie{trie: *trie}, nil
+ return &StateTrie{trie: *trie, preimages: db.preimages}, nil
+}
+
+// Get returns the value for key stored in the trie.
+// The value bytes must not be modified by the caller.
+func (t *StateTrie) Get(key []byte) []byte {
+ res, err := t.TryGet(key)
+ if err != nil {
+ log.Errorf("Unhandled trie error in StateTrie.Get: %v", err)
+ }
+ return res
}
// TryGet returns the value for key stored in the trie.
// The value bytes must not be modified by the caller.
-// If a node was not found in the database, a MissingNodeError is returned.
+// If the specified node is not in the trie, nil will be returned.
+// If a trie node is not found in the database, a MissingNodeError is returned.
func (t *StateTrie) TryGet(key []byte) ([]byte, error) {
return t.trie.TryGet(t.hashKey(key))
}
-func (t *StateTrie) TryGetAccount(key []byte) (*types.StateAccount, error) {
- var ret types.StateAccount
- res, err := t.TryGet(key)
- if err != nil {
- // log.Error(fmt.Sprintf("Unhandled trie error: %v", err))
- panic(fmt.Sprintf("Unhandled trie error: %v", err))
- return &ret, err
+// TryGetAccount attempts to retrieve an account with provided account address.
+// If the specified account is not in the trie, nil will be returned.
+// If a trie node is not found in the database, a MissingNodeError is returned.
+func (t *StateTrie) TryGetAccount(address common.Address) (*types.StateAccount, error) {
+ res, err := t.trie.TryGet(t.hashKey(address.Bytes()))
+ if res == nil || err != nil {
+ return nil, err
}
- if res == nil {
- return nil, nil
+ ret := new(types.StateAccount)
+ err = rlp.DecodeBytes(res, ret)
+ return ret, err
+}
+
+// TryGetAccountByHash does the same thing as TryGetAccount, except that it
+// expects an account hash, i.e. the hash of the address. This constitutes an
+// abstraction leak, since the client code needs to know the key format.
+func (t *StateTrie) TryGetAccountByHash(addrHash common.Hash) (*types.StateAccount, error) {
+ res, err := t.trie.TryGet(addrHash.Bytes())
+ if res == nil || err != nil {
+ return nil, err
}
- err = rlp.DecodeBytes(res, &ret)
- return &ret, err
+ ret := new(types.StateAccount)
+ err = rlp.DecodeBytes(res, ret)
+ return ret, err
}
// TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not
@@ -103,12 +109,124 @@ func (t *StateTrie) TryGetNode(path []byte) ([]byte, int, error) {
return t.trie.TryGetNode(path)
}
+// Update associates key with value in the trie. Subsequent calls to
+// Get will return value. If value has length zero, any existing value
+// is deleted from the trie and calls to Get will return nil.
+//
+// The value bytes must not be modified by the caller while they are
+// stored in the trie.
+func (t *StateTrie) Update(key, value []byte) {
+ if err := t.TryUpdate(key, value); err != nil {
+ log.Errorf("Unhandled trie error in StateTrie.Update: %v", err)
+ }
+}
+
+// TryUpdate associates key with value in the trie. Subsequent calls to
+// Get will return value. If value has length zero, any existing value
+// is deleted from the trie and calls to Get will return nil.
+//
+// The value bytes must not be modified by the caller while they are
+// stored in the trie.
+//
+// If a node is not found in the database, a MissingNodeError is returned.
+func (t *StateTrie) TryUpdate(key, value []byte) error {
+ hk := t.hashKey(key)
+ err := t.trie.TryUpdate(hk, value)
+ if err != nil {
+ return err
+ }
+ t.getSecKeyCache()[string(hk)] = common.CopyBytes(key)
+ return nil
+}
+
+// TryUpdateAccount abstracts the write of an account to the
+// secure trie.
+func (t *StateTrie) TryUpdateAccount(address common.Address, acc *types.StateAccount) error {
+ hk := t.hashKey(address.Bytes())
+ data, err := rlp.EncodeToBytes(acc)
+ if err != nil {
+ return err
+ }
+ if err := t.trie.TryUpdate(hk, data); err != nil {
+ return err
+ }
+ t.getSecKeyCache()[string(hk)] = address.Bytes()
+ return nil
+}
+
+// Delete removes any existing value for key from the trie.
+func (t *StateTrie) Delete(key []byte) {
+ if err := t.TryDelete(key); err != nil {
+ log.Errorf("Unhandled trie error in StateTrie.Delete: %v", err)
+ }
+}
+
+// TryDelete removes any existing value for key from the trie.
+// If the specified trie node is not in the trie, nothing will be changed.
+// If a node is not found in the database, a MissingNodeError is returned.
+func (t *StateTrie) TryDelete(key []byte) error {
+ hk := t.hashKey(key)
+ delete(t.getSecKeyCache(), string(hk))
+ return t.trie.TryDelete(hk)
+}
+
+// TryDeleteAccount abstracts an account deletion from the trie.
+func (t *StateTrie) TryDeleteAccount(address common.Address) error {
+ hk := t.hashKey(address.Bytes())
+ delete(t.getSecKeyCache(), string(hk))
+ return t.trie.TryDelete(hk)
+}
+
+// GetKey returns the sha3 preimage of a hashed key that was
+// previously used to store a value.
+func (t *StateTrie) GetKey(shaKey []byte) []byte {
+ if key, ok := t.getSecKeyCache()[string(shaKey)]; ok {
+ return key
+ }
+ if t.preimages == nil {
+ return nil
+ }
+ return t.preimages.preimage(common.BytesToHash(shaKey))
+}
+
+// Commit collects all dirty nodes in the trie and replaces them with the
+// corresponding node hash. All collected nodes (including dirty leaves if
+// collectLeaf is true) will be encapsulated into a nodeset for return.
+// The returned nodeset can be nil if the trie is clean (nothing to commit).
+// All cached preimages will be also flushed if preimages recording is enabled.
+// Once the trie is committed, it's not usable anymore. A new trie must
+// be created with the new root and an updated trie database for further use.
+func (t *StateTrie) Commit(collectLeaf bool) (common.Hash, *NodeSet) {
+ // Write all the pre-images to the actual disk database
+ if len(t.getSecKeyCache()) > 0 {
+ if t.preimages != nil {
+ preimages := make(map[common.Hash][]byte)
+ for hk, key := range t.secKeyCache {
+ preimages[common.BytesToHash([]byte(hk))] = key
+ }
+ t.preimages.insertPreimage(preimages)
+ }
+ t.secKeyCache = make(map[string][]byte)
+ }
+ // Commit the trie and return its modified nodeset.
+ return t.trie.Commit(collectLeaf)
+}
+
// Hash returns the root hash of StateTrie. It does not write to the
// database and can be used even if the trie doesn't have one.
func (t *StateTrie) Hash() common.Hash {
return t.trie.Hash()
}
+// Copy returns a copy of StateTrie.
+func (t *StateTrie) Copy() *StateTrie {
+ return &StateTrie{
+ trie: *t.trie.Copy(),
+ preimages: t.preimages,
+ secKeyCache: t.secKeyCache,
+ }
+}
+
// NodeIterator returns an iterator that returns nodes of the underlying trie. Iteration
// starts at the key after the given start key.
func (t *StateTrie) NodeIterator(start []byte) NodeIterator {
@@ -126,3 +244,14 @@ func (t *StateTrie) hashKey(key []byte) []byte {
returnHasherToPool(h)
return t.hashKeyBuf[:]
}
+
+// getSecKeyCache returns the current secure key cache, creating a new one if
+// ownership changed (i.e. the current secure trie is a copy of another owning
+// the actual cache).
+func (t *StateTrie) getSecKeyCache() map[string][]byte {
+ if t != t.secKeyCacheOwner {
+ t.secKeyCacheOwner = t
+ t.secKeyCache = make(map[string][]byte)
+ }
+ return t.secKeyCache
+}
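The StateTrie methods above all follow a hash-the-key-then-delegate pattern, with the secure-key cache recording preimages until the next Commit. A short sketch of that flow, assuming the constructors and codec constant used elsewhere in this diff (TrieID, NewDatabase, StateTrieCodec); the helper is illustrative only:

// exampleStateTrie shows how the secure-key cache interacts with GetKey
// and Commit: Update stores the preimage under the hashed key, GetKey
// serves it back, and Commit flushes it to the preimage store (if enabled).
func exampleStateTrie() {
	db := NewDatabase(rawdb.NewMemoryDatabase())
	st, _ := NewStateTrie(TrieID(common.Hash{}), db, StateTrieCodec)

	key := []byte("alice")
	st.Update(key, []byte("balance"))

	hashed := crypto.Keccak256(key)
	_ = st.GetKey(hashed) // returns []byte("alice") from secKeyCache

	// Commit hands the cached preimages to the database (when preimage
	// recording is enabled), resets the cache, and returns the new root
	// plus the dirty nodeset.
	root, nodes := st.Commit(false)
	_, _ = root, nodes
}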
diff --git a/trie_by_cid/trie/secure_trie_test.go b/trie_by_cid/trie/secure_trie_test.go
new file mode 100644
index 0000000..41edbc3
--- /dev/null
+++ b/trie_by_cid/trie/secure_trie_test.go
@@ -0,0 +1,148 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package trie
+
+import (
+ "bytes"
+ "fmt"
+ "runtime"
+ "sync"
+ "testing"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/ethereum/go-ethereum/core/rawdb"
+ "github.com/ethereum/go-ethereum/crypto"
+)
+
+// makeTestStateTrie creates a large enough secure trie for testing.
+func makeTestStateTrie() (*Database, *StateTrie, map[string][]byte) {
+ // Create an empty trie
+ triedb := NewDatabase(rawdb.NewMemoryDatabase())
+ trie, _ := NewStateTrie(TrieID(common.Hash{}), triedb, StateTrieCodec)
+
+ // Fill it with some arbitrary data
+ content := make(map[string][]byte)
+ for i := byte(0); i < 255; i++ {
+ // Map the same data under multiple keys
+ key, val := common.LeftPadBytes([]byte{1, i}, 32), []byte{i}
+ content[string(key)] = val
+ trie.Update(key, val)
+
+ key, val = common.LeftPadBytes([]byte{2, i}, 32), []byte{i}
+ content[string(key)] = val
+ trie.Update(key, val)
+
+ // Add some other data to inflate the trie
+ for j := byte(3); j < 13; j++ {
+ key, val = common.LeftPadBytes([]byte{j, i}, 32), []byte{j, i}
+ content[string(key)] = val
+ trie.Update(key, val)
+ }
+ }
+ root, nodes := trie.Commit(false)
+ if err := triedb.Update(NewWithNodeSet(nodes)); err != nil {
+ panic(fmt.Errorf("failed to commit db %v", err))
+ }
+ // Re-create the trie based on the new state
+ trie, _ = NewStateTrie(TrieID(root), triedb, StateTrieCodec)
+ return triedb, trie, content
+}
+
+func TestSecureDelete(t *testing.T) {
+ trie, err := NewStateTrie(TrieID(common.Hash{}), NewDatabase(rawdb.NewMemoryDatabase()), StateTrieCodec)
+ if err != nil {
+ t.Fatal(err)
+ }
+ vals := []struct{ k, v string }{
+ {"do", "verb"},
+ {"ether", "wookiedoo"},
+ {"horse", "stallion"},
+ {"shaman", "horse"},
+ {"doge", "coin"},
+ {"ether", ""},
+ {"dog", "puppy"},
+ {"shaman", ""},
+ }
+ for _, val := range vals {
+ if val.v != "" {
+ trie.Update([]byte(val.k), []byte(val.v))
+ } else {
+ trie.Delete([]byte(val.k))
+ }
+ }
+ hash := trie.Hash()
+ exp := common.HexToHash("29b235a58c3c25ab83010c327d5932bcf05324b7d6b1185e650798034783ca9d")
+ if hash != exp {
+ t.Errorf("expected %x got %x", exp, hash)
+ }
+}
+
+func TestSecureGetKey(t *testing.T) {
+ trie, err := NewStateTrie(TrieID(common.Hash{}), NewDatabase(rawdb.NewMemoryDatabase()), StateTrieCodec)
+ if err != nil {
+ t.Fatal(err)
+ }
+ trie.Update([]byte("foo"), []byte("bar"))
+
+ key := []byte("foo")
+ value := []byte("bar")
+ seckey := crypto.Keccak256(key)
+
+ if !bytes.Equal(trie.Get(key), value) {
+ t.Errorf("Get did not return bar")
+ }
+ if k := trie.GetKey(seckey); !bytes.Equal(k, key) {
+ t.Errorf("GetKey returned %q, want %q", k, key)
+ }
+}
+
+func TestStateTrieConcurrency(t *testing.T) {
+ // Create an initial trie and copy it for concurrent access
+ _, trie, _ := makeTestStateTrie()
+
+ threads := runtime.NumCPU()
+ tries := make([]*StateTrie, threads)
+ for i := 0; i < threads; i++ {
+ tries[i] = trie.Copy()
+ }
+ // Start a batch of goroutines interacting with the trie
+ pend := new(sync.WaitGroup)
+ pend.Add(threads)
+ for i := 0; i < threads; i++ {
+ go func(index int) {
+ defer pend.Done()
+
+ for j := byte(0); j < 255; j++ {
+ // Map the same data under multiple keys
+ key, val := common.LeftPadBytes([]byte{byte(index), 1, j}, 32), []byte{j}
+ tries[index].Update(key, val)
+
+ key, val = common.LeftPadBytes([]byte{byte(index), 2, j}, 32), []byte{j}
+ tries[index].Update(key, val)
+
+ // Add some other data to inflate the trie
+ for k := byte(3); k < 13; k++ {
+ key, val = common.LeftPadBytes([]byte{byte(index), k, j}, 32), []byte{k, j}
+ tries[index].Update(key, val)
+ }
+ }
+ tries[index].Commit(false)
+ }(i)
+ }
+ // Wait for all threads to finish
+ pend.Wait()
+}
diff --git a/trie_by_cid/trie/tracer.go b/trie_by_cid/trie/tracer.go
new file mode 100644
index 0000000..a27e371
--- /dev/null
+++ b/trie_by_cid/trie/tracer.go
@@ -0,0 +1,125 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package trie
+
+import "github.com/ethereum/go-ethereum/common"
+
+// tracer tracks changes to trie nodes. During trie operations, some nodes
+// can be deleted from the trie, but those deletions are not captured by
+// trie.Hasher or trie.Committer, so the deleted nodes would never be
+// removed from disk. The tracer is an auxiliary tool used to track all
+// insert and delete operations on the trie and eventually capture all
+// deleted nodes.
+//
+// The changed nodes fall into two main categories: leaf nodes and
+// intermediate nodes. The former are inserted/deleted by callers, while
+// the latter are inserted/deleted to preserve the structural rules of the
+// trie. The tracer tracks both, whether or not a node is embedded in its
+// parent; valueNodes, however, are never tracked.
+//
+// Besides, the tracer also records the original value of nodes when they
+// are resolved from disk. These pre-values will be used to construct trie
+// history in the future.
+//
+// Note: the tracer is not thread-safe; callers are responsible for
+// handling concurrency themselves.
+type tracer struct {
+ inserts map[string]struct{}
+ deletes map[string]struct{}
+ accessList map[string][]byte
+}
+
+// newTracer initializes the tracer for capturing trie changes.
+func newTracer() *tracer {
+ return &tracer{
+ inserts: make(map[string]struct{}),
+ deletes: make(map[string]struct{}),
+ accessList: make(map[string][]byte),
+ }
+}
+
+// onRead tracks the newly loaded trie node and caches the rlp-encoded
+// blob internally. Don't modify the value outside of this function, since
+// it's not deep-copied.
+func (t *tracer) onRead(path []byte, val []byte) {
+ t.accessList[string(path)] = val
+}
+
+// onInsert tracks the newly inserted trie node. If it's already
+// in the deletion set (resurrected node), then just wipe it from
+// the deletion set as it's "untouched".
+func (t *tracer) onInsert(path []byte) {
+ if _, present := t.deletes[string(path)]; present {
+ delete(t.deletes, string(path))
+ return
+ }
+ t.inserts[string(path)] = struct{}{}
+}
+
+// onDelete tracks the newly deleted trie node. If it's already
+// in the addition set, then just wipe it from the addition set
+// as it's untouched.
+func (t *tracer) onDelete(path []byte) {
+ if _, present := t.inserts[string(path)]; present {
+ delete(t.inserts, string(path))
+ return
+ }
+ t.deletes[string(path)] = struct{}{}
+}
+
+// reset clears the content tracked by tracer.
+func (t *tracer) reset() {
+ t.inserts = make(map[string]struct{})
+ t.deletes = make(map[string]struct{})
+ t.accessList = make(map[string][]byte)
+}
+
+// copy returns a deep copied tracer instance.
+func (t *tracer) copy() *tracer {
+ var (
+ inserts = make(map[string]struct{})
+ deletes = make(map[string]struct{})
+ accessList = make(map[string][]byte)
+ )
+ for path := range t.inserts {
+ inserts[path] = struct{}{}
+ }
+ for path := range t.deletes {
+ deletes[path] = struct{}{}
+ }
+ for path, blob := range t.accessList {
+ accessList[path] = common.CopyBytes(blob)
+ }
+ return &tracer{
+ inserts: inserts,
+ deletes: deletes,
+ accessList: accessList,
+ }
+}
+
+// markDeletions puts all tracked deletions into the provided nodeset.
+func (t *tracer) markDeletions(set *NodeSet) {
+ for path := range t.deletes {
+ // It's possible that some deleted nodes were embedded in their
+ // parent before; those deletions delete nothing and have no
+ // effect, so filter them out.
+ if _, ok := set.accessList[path]; !ok {
+ continue
+ }
+ set.markDeleted([]byte(path))
+ }
+}
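One consequence of the netting in onInsert/onDelete above: a node that is inserted and then deleted within the same commit window counts as untouched and ends up in neither set. A tiny illustrative sketch against the unexported tracer API from this file:

// exampleTracer demonstrates the tracer's insert/delete netting.
func exampleTracer() {
	tr := newTracer()
	path := []byte{0x01, 0x02}
	tr.onInsert(path) // recorded in the insert set
	tr.onDelete(path) // cancels the pending insert instead of recording a delete
	_, inserted := tr.inserts[string(path)]
	_, deleted := tr.deletes[string(path)]
	// Both inserted and deleted are now false.
	_, _ = inserted, deleted
}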
diff --git a/trie_by_cid/trie/trie.go b/trie_by_cid/trie/trie.go
index b2d8986..5ad22c1 100644
--- a/trie_by_cid/trie/trie.go
+++ b/trie_by_cid/trie/trie.go
@@ -7,18 +7,15 @@ import (
"fmt"
"github.com/ethereum/go-ethereum/common"
- "github.com/ethereum/go-ethereum/crypto"
+ "github.com/ethereum/go-ethereum/core/types"
+ log "github.com/sirupsen/logrus"
- util "github.com/cerc-io/ipld-eth-statedb/internal"
"github.com/ethereum/go-ethereum/statediff/indexer/ipld"
)
var (
- // emptyRoot is the known root hash of an empty trie.
- emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421")
-
- // emptyState is the known hash of an empty state trie entry.
- emptyState = crypto.Keccak256Hash(nil)
+ StateTrieCodec uint64 = ipld.MEthStateTrie
+ StorageTrieCodec uint64 = ipld.MEthStorageTrie
)
// Trie is a Merkle Patricia Trie. Use New to create a trie that sits on
@@ -37,29 +34,48 @@ type Trie struct {
// actually unhashed nodes.
unhashed int
- // db is the handler trie can retrieve nodes from. It's
- // only for reading purpose and not available for writing.
- db *Database
+ // reader is the handler trie can retrieve nodes from.
+ reader *trieReader
- // Multihash codec for key encoding
- codec uint64
+ // tracer is the tool to track the trie changes.
+ // It will be reset after each commit operation.
+ tracer *tracer
}
-// New creates a trie with an existing root node from db and an assigned
-// owner for storage proximity.
-//
-// If root is the zero hash or the sha3 hash of an empty string, the
-// trie is initially empty and does not require a database. Otherwise,
-// New will panic if db is nil and returns a MissingNodeError if root does
-// not exist in the database. Accessing the trie loads nodes from db on demand.
-func New(owner common.Hash, root common.Hash, db *Database, codec uint64) (*Trie, error) {
- trie := &Trie{
- owner: owner,
- db: db,
- codec: codec,
+// newFlag returns the cache flag value for a newly created node.
+func (t *Trie) newFlag() nodeFlag {
+ return nodeFlag{dirty: true}
+}
+
+// Copy returns a copy of Trie.
+func (t *Trie) Copy() *Trie {
+ return &Trie{
+ root: t.root,
+ owner: t.owner,
+ unhashed: t.unhashed,
+ reader: t.reader,
+ tracer: t.tracer.copy(),
}
- if root != (common.Hash{}) && root != emptyRoot {
- rootnode, err := trie.resolveHash(root[:], nil)
+}
+
+// New creates a trie instance with the provided trie id and a read-only
+// database. The state specified by the trie id must be available, otherwise
+// an error is returned. If the trie root specified by the trie id is the
+// zero hash or the sha3 hash of an empty string, the trie is initially
+// empty; otherwise the root node must be present in the database, or a
+// MissingNodeError is returned.
+func New(id *ID, db NodeReader, codec uint64) (*Trie, error) {
+ reader, err := newTrieReader(id.StateRoot, id.Owner, db, codec)
+ if err != nil {
+ return nil, err
+ }
+ trie := &Trie{
+ owner: id.Owner,
+ reader: reader,
+ tracer: newTracer(),
+ }
+ if id.Root != (common.Hash{}) && id.Root != types.EmptyRootHash {
+ rootnode, err := trie.resolveAndTrack(id.Root[:], nil)
if err != nil {
return nil, err
}
@@ -70,7 +86,10 @@ func New(owner common.Hash, root common.Hash, db *Database, codec uint64) (*Trie
// NewEmpty is a shortcut to create an empty trie. It's mostly used in tests.
func NewEmpty(db *Database) *Trie {
- tr, _ := New(common.Hash{}, common.Hash{}, db, ipld.MEthStateTrie)
+ tr, err := New(TrieID(common.Hash{}), db, StateTrieCodec)
+ if err != nil {
+ panic(err)
+ }
return tr
}
@@ -80,6 +99,16 @@ func (t *Trie) NodeIterator(start []byte) NodeIterator {
return newNodeIterator(t, start)
}
+// Get returns the value for key stored in the trie.
+// The value bytes must not be modified by the caller.
+func (t *Trie) Get(key []byte) []byte {
+ res, err := t.TryGet(key)
+ if err != nil {
+ log.Errorf("Unhandled trie error in Trie.Get: %v", err)
+ }
+ return res
+}
+
// TryGet returns the value for key stored in the trie.
// The value bytes must not be modified by the caller.
// If a node was not found in the database, a MissingNodeError is returned.
@@ -116,7 +145,7 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode
}
return value, n, didResolve, err
case hashNode:
- child, err := t.resolveHash(n, key[:pos])
+ child, err := t.resolveAndTrack(n, key[:pos])
if err != nil {
return nil, n, true, err
}
@@ -130,7 +159,7 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode
// TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not
// possible to use keybyte-encoding as the path might contain odd nibbles.
func (t *Trie) TryGetNode(path []byte) ([]byte, int, error) {
- item, newroot, resolved, err := t.tryGetNode(t.root, CompactToHex(path), 0)
+ item, newroot, resolved, err := t.tryGetNode(t.root, compactToHex(path), 0)
if err != nil {
return nil, resolved, err
}
@@ -162,11 +191,7 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new
if hash == nil {
return nil, origNode, 0, errors.New("non-consensus node")
}
- cid, err := util.Keccak256ToCid(t.codec, hash)
- if err != nil {
- return nil, origNode, 0, err
- }
- blob, err := t.db.Node(cid)
+ blob, err := t.reader.nodeBlob(path, common.BytesToHash(hash))
return blob, origNode, 1, err
}
// Path still needs to be traversed, descend into children
@@ -196,7 +221,7 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new
return item, n, resolved, err
case hashNode:
- child, err := t.resolveHash(n, path[:pos])
+ child, err := t.resolveAndTrack(n, path[:pos])
if err != nil {
return nil, n, 1, err
}
@@ -208,42 +233,309 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new
}
}
-// resolveHash loads node from the underlying database with the provided
-// node hash and path prefix.
-func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) {
- cid, err := util.Keccak256ToCid(t.codec, n)
- if err != nil {
- return nil, err
+// Update associates key with value in the trie. Subsequent calls to
+// Get will return value. If value has length zero, any existing value
+// is deleted from the trie and calls to Get will return nil.
+//
+// The value bytes must not be modified by the caller while they are
+// stored in the trie.
+func (t *Trie) Update(key, value []byte) {
+ if err := t.TryUpdate(key, value); err != nil {
+ log.Errorf("Unhandled trie error in Trie.Update: %v", err)
}
- enc, err := t.db.Node(cid)
- if err != nil {
- return nil, &MissingNodeError{Owner: t.owner, NodeHash: n, Path: prefix, err: err}
- }
- node, err := decodeNodeUnsafe(n, enc)
- if err != nil {
- return nil, err
- }
- if node != nil {
- return node, nil
- }
- return nil, &MissingNodeError{Owner: t.owner, NodeHash: n, Path: prefix}
}
-// resolveHash loads rlp-encoded node blob from the underlying database
-// with the provided node hash and path prefix.
-func (t *Trie) resolveBlob(n hashNode, prefix []byte) ([]byte, error) {
- cid, err := util.Keccak256ToCid(t.codec, n)
+// TryUpdate associates key with value in the trie. Subsequent calls to
+// Get will return value. If value has length zero, any existing value
+// is deleted from the trie and calls to Get will return nil.
+//
+// The value bytes must not be modified by the caller while they are
+// stored in the trie.
+//
+// If a node was not found in the database, a MissingNodeError is returned.
+func (t *Trie) TryUpdate(key, value []byte) error {
+ return t.tryUpdate(key, value)
+}
+
+// tryUpdate expects an RLP-encoded value and performs the core function
+// for TryUpdate and TryUpdateAccount.
+func (t *Trie) tryUpdate(key, value []byte) error {
+ t.unhashed++
+ k := keybytesToHex(key)
+ if len(value) != 0 {
+ _, n, err := t.insert(t.root, nil, k, valueNode(value))
+ if err != nil {
+ return err
+ }
+ t.root = n
+ } else {
+ _, n, err := t.delete(t.root, nil, k)
+ if err != nil {
+ return err
+ }
+ t.root = n
+ }
+ return nil
+}
+
+func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error) {
+ if len(key) == 0 {
+ if v, ok := n.(valueNode); ok {
+ return !bytes.Equal(v, value.(valueNode)), value, nil
+ }
+ return true, value, nil
+ }
+ switch n := n.(type) {
+ case *shortNode:
+ matchlen := prefixLen(key, n.Key)
+ // If the whole key matches, keep this short node as is
+ // and only update the value.
+ if matchlen == len(n.Key) {
+ dirty, nn, err := t.insert(n.Val, append(prefix, key[:matchlen]...), key[matchlen:], value)
+ if !dirty || err != nil {
+ return false, n, err
+ }
+ return true, &shortNode{n.Key, nn, t.newFlag()}, nil
+ }
+ // Otherwise branch out at the index where they differ.
+ branch := &fullNode{flags: t.newFlag()}
+ var err error
+ _, branch.Children[n.Key[matchlen]], err = t.insert(nil, append(prefix, n.Key[:matchlen+1]...), n.Key[matchlen+1:], n.Val)
+ if err != nil {
+ return false, nil, err
+ }
+ _, branch.Children[key[matchlen]], err = t.insert(nil, append(prefix, key[:matchlen+1]...), key[matchlen+1:], value)
+ if err != nil {
+ return false, nil, err
+ }
+ // Replace this shortNode with the branch if it occurs at index 0.
+ if matchlen == 0 {
+ return true, branch, nil
+ }
+ // New branch node is created as a child of the original short node.
+ // Track the newly inserted node in the tracer. The node identifier
+ // passed is the path from the root node.
+ t.tracer.onInsert(append(prefix, key[:matchlen]...))
+
+ // Replace it with a short node leading up to the branch.
+ return true, &shortNode{key[:matchlen], branch, t.newFlag()}, nil
+
+ case *fullNode:
+ dirty, nn, err := t.insert(n.Children[key[0]], append(prefix, key[0]), key[1:], value)
+ if !dirty || err != nil {
+ return false, n, err
+ }
+ n = n.copy()
+ n.flags = t.newFlag()
+ n.Children[key[0]] = nn
+ return true, n, nil
+
+ case nil:
+ // New short node is created and track it in the tracer. The node identifier
+ // passed is the path from the root node. Note the valueNode won't be tracked
+ // since it's always embedded in its parent.
+ t.tracer.onInsert(prefix)
+
+ return true, &shortNode{key, value, t.newFlag()}, nil
+
+ case hashNode:
+ // We've hit a part of the trie that isn't loaded yet. Load
+ // the node and insert into it. This leaves all child nodes on
+ // the path to the value in the trie.
+ rn, err := t.resolveAndTrack(n, prefix)
+ if err != nil {
+ return false, nil, err
+ }
+ dirty, nn, err := t.insert(rn, prefix, key, value)
+ if !dirty || err != nil {
+ return false, rn, err
+ }
+ return true, nn, nil
+
+ default:
+ panic(fmt.Sprintf("%T: invalid node: %v", n, n))
+ }
+}
+
+// Delete removes any existing value for key from the trie.
+func (t *Trie) Delete(key []byte) {
+ if err := t.TryDelete(key); err != nil {
+ log.Errorf("Unhandled trie error in Trie.Delete: %v", err)
+ }
+}
+
+// TryDelete removes any existing value for key from the trie.
+// If a node was not found in the database, a MissingNodeError is returned.
+func (t *Trie) TryDelete(key []byte) error {
+ t.unhashed++
+ k := keybytesToHex(key)
+ _, n, err := t.delete(t.root, nil, k)
+ if err != nil {
+ return err
+ }
+ t.root = n
+ return nil
+}
+
+// delete returns the new root of the trie with key deleted.
+// It reduces the trie to minimal form by simplifying
+// nodes on the way up after deleting recursively.
+func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
+ switch n := n.(type) {
+ case *shortNode:
+ matchlen := prefixLen(key, n.Key)
+ if matchlen < len(n.Key) {
+ return false, n, nil // don't replace n on mismatch
+ }
+ if matchlen == len(key) {
+ // The matched short node is deleted entirely; track it in
+ // the deletion set. The valueNode itself doesn't need to be
+ // tracked at all, since it's always embedded in its parent.
+ t.tracer.onDelete(prefix)
+
+ return true, nil, nil // remove n entirely for whole matches
+ }
+ // The key is longer than n.Key. Remove the remaining suffix
+ // from the subtrie. Child can never be nil here since the
+ // subtrie must contain at least two other values with keys
+ // longer than n.Key.
+ dirty, child, err := t.delete(n.Val, append(prefix, key[:len(n.Key)]...), key[len(n.Key):])
+ if !dirty || err != nil {
+ return false, n, err
+ }
+ switch child := child.(type) {
+ case *shortNode:
+ // The child shortNode is merged into its parent; track the
+ // merged-away child as deleted as well.
+ t.tracer.onDelete(append(prefix, n.Key...))
+
+ // Deleting from the subtrie reduced it to another
+ // short node. Merge the nodes to avoid creating a
+ // shortNode{..., shortNode{...}}. Use concat (which
+ // always creates a new slice) instead of append to
+ // avoid modifying n.Key since it might be shared with
+ // other nodes.
+ return true, &shortNode{concat(n.Key, child.Key...), child.Val, t.newFlag()}, nil
+ default:
+ return true, &shortNode{n.Key, child, t.newFlag()}, nil
+ }
+
+ case *fullNode:
+ dirty, nn, err := t.delete(n.Children[key[0]], append(prefix, key[0]), key[1:])
+ if !dirty || err != nil {
+ return false, n, err
+ }
+ n = n.copy()
+ n.flags = t.newFlag()
+ n.Children[key[0]] = nn
+
+ // Because n is a full node, it must've contained at least two children
+ // before the delete operation. If the new child value is non-nil, n still
+ // has at least two children after the deletion, and cannot be reduced to
+ // a short node.
+ if nn != nil {
+ return true, n, nil
+ }
+ // Reduction:
+ // Check how many non-nil entries are left after deleting and
+ // reduce the full node to a short node if only one entry is
+ // left. Since n must've contained at least two children
+ // before deletion (otherwise it would not be a full node) n
+ // can never be reduced to nil.
+ //
+ // When the loop is done, pos contains the index of the single
+ // value that is left in n or -2 if n contains at least two
+ // values.
+ pos := -1
+ for i, cld := range &n.Children {
+ if cld != nil {
+ if pos == -1 {
+ pos = i
+ } else {
+ pos = -2
+ break
+ }
+ }
+ }
+ if pos >= 0 {
+ if pos != 16 {
+ // If the remaining entry is a short node, it replaces
+ // n and its key gets the missing nibble tacked to the
+ // front. This avoids creating an invalid
+ // shortNode{..., shortNode{...}}. Since the entry
+ // might not be loaded yet, resolve it just for this
+ // check.
+ cnode, err := t.resolve(n.Children[pos], append(prefix, byte(pos)))
+ if err != nil {
+ return false, nil, err
+ }
+ if cnode, ok := cnode.(*shortNode); ok {
+ // Replace the entire full node with the short node.
+ // Mark the original short node as deleted since the
+ // value is embedded into the parent now.
+ t.tracer.onDelete(append(prefix, byte(pos)))
+
+ k := append([]byte{byte(pos)}, cnode.Key...)
+ return true, &shortNode{k, cnode.Val, t.newFlag()}, nil
+ }
+ }
+ // Otherwise, n is replaced by a one-nibble short node
+ // containing the child.
+ return true, &shortNode{[]byte{byte(pos)}, n.Children[pos], t.newFlag()}, nil
+ }
+ // n still contains at least two values and cannot be reduced.
+ return true, n, nil
+
+ case valueNode:
+ return true, nil, nil
+
+ case nil:
+ return false, nil, nil
+
+ case hashNode:
+ // We've hit a part of the trie that isn't loaded yet. Load
+ // the node and delete from it. This leaves all child nodes on
+ // the path to the value in the trie.
+ rn, err := t.resolveAndTrack(n, prefix)
+ if err != nil {
+ return false, nil, err
+ }
+ dirty, nn, err := t.delete(rn, prefix, key)
+ if !dirty || err != nil {
+ return false, rn, err
+ }
+ return true, nn, nil
+
+ default:
+ panic(fmt.Sprintf("%T: invalid node: %v (%v)", n, n, key))
+ }
+}
+
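+// concat returns a new slice holding s1 followed by s2. Unlike append,
+// it always allocates, so a shared n.Key slice is never mutated.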
+func concat(s1 []byte, s2 ...byte) []byte {
+ r := make([]byte, len(s1)+len(s2))
+ copy(r, s1)
+ copy(r[len(s1):], s2)
+ return r
+}
+
+func (t *Trie) resolve(n node, prefix []byte) (node, error) {
+ if n, ok := n.(hashNode); ok {
+ return t.resolveAndTrack(n, prefix)
+ }
+ return n, nil
+}
+
+// resolveAndTrack loads the node from the underlying store with the given node
+// hash and path prefix, and also records the loaded node blob in the tracer as
+// the node's original value. The rlp-encoded blob is preferred when loading from
+// the database, since a blob is easy to decode into a node while a node is
+// complex to encode into a blob.
+func (t *Trie) resolveAndTrack(n hashNode, prefix []byte) (node, error) {
+ blob, err := t.reader.nodeBlob(prefix, common.BytesToHash(n))
if err != nil {
return nil, err
}
- blob, err := t.db.Node(cid)
- if err != nil {
- return nil, err
- }
- if len(blob) != 0 {
- return blob, nil
- }
- return nil, &MissingNodeError{Owner: t.owner, NodeHash: n, Path: prefix}
+ t.tracer.onRead(prefix, blob)
+ return mustDecodeNode(n, blob), nil
}
// Hash returns the root hash of the trie. It does not write to the
@@ -254,15 +546,59 @@ func (t *Trie) Hash() common.Hash {
return common.BytesToHash(hash.(hashNode))
}
+// Commit collects all dirty nodes in the trie and replaces them with the
+// corresponding node hash. All collected nodes (including dirty leaves if
+// collectLeaf is true) will be encapsulated into a nodeset for return.
+// The returned nodeset can be nil if the trie is clean (nothing to commit).
+// Once the trie is committed, it's not usable anymore. A new trie must
+// be created with the new root and an updated trie database for further use.
+func (t *Trie) Commit(collectLeaf bool) (common.Hash, *NodeSet) {
+ defer t.tracer.reset()
+
+ nodes := NewNodeSet(t.owner, t.tracer.accessList)
+ t.tracer.markDeletions(nodes)
+
+ // An empty trie can arise in two situations:
+ // - The trie was empty and no update happens
+ // - The trie was non-empty and all nodes are dropped
+ if t.root == nil {
+ return types.EmptyRootHash, nodes
+ }
+ // Derive the hash for all dirty nodes first. We hold the assumption
+ // in the following procedure that all nodes are hashed.
+ rootHash := t.Hash()
+
+ // Do a quick check if we really need to commit. This can happen e.g.
+ // if we load a trie for reading storage values, but don't write to it.
+ if hashedNode, dirty := t.root.cache(); !dirty {
+ // Replace the root node with the origin hash in order to
+ // ensure all resolved nodes are dropped after the commit.
+ t.root = hashedNode
+ return rootHash, nil
+ }
+ t.root = newCommitter(nodes, collectLeaf).Commit(t.root)
+ return rootHash, nodes
+}
+
// hashRoot calculates the root hash of the given trie
func (t *Trie) hashRoot() (node, node) {
if t.root == nil {
- return hashNode(emptyRoot.Bytes()), nil
+ return hashNode(types.EmptyRootHash.Bytes()), nil
}
// If the number of changes is below 100, we let one thread handle it
h := newHasher(t.unhashed >= 100)
- defer returnHasherToPool(h)
+ defer func() {
+ returnHasherToPool(h)
+ t.unhashed = 0
+ }()
hashed, cached := h.hash(t.root, true)
- t.unhashed = 0
return hashed, cached
}
+
+// Reset drops the referenced root node and cleans all internal state.
+func (t *Trie) Reset() {
+ t.root = nil
+ t.owner = common.Hash{}
+ t.unhashed = 0
+ t.tracer.reset()
+}
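Because Commit invalidates the trie, the expected lifecycle is: mutate, commit, push the returned nodeset into the database, then reopen at the new root. A sketch of that cycle, using Database.Update and NewWithNodeSet as they are used in the tests elsewhere in this diff; the helper name is hypothetical:

// exampleCommitCycle walks through one mutate/commit/reopen cycle.
func exampleCommitCycle(db *Database) (common.Hash, error) {
	tr := NewEmpty(db)
	tr.Update([]byte("key"), []byte("value"))

	root, nodes := tr.Commit(false)
	if nodes != nil { // a nil nodeset means there is nothing to persist
		if err := db.Update(NewWithNodeSet(nodes)); err != nil {
			return common.Hash{}, err
		}
	}
	// The committed trie is no longer usable; reopen from the new root.
	_, err := New(TrieID(root), db, StateTrieCodec)
	return root, err
}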
diff --git a/trie_by_cid/trie/trie_id.go b/trie_by_cid/trie/trie_id.go
new file mode 100644
index 0000000..8ab490c
--- /dev/null
+++ b/trie_by_cid/trie/trie_id.go
@@ -0,0 +1,55 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package trie
+
+import "github.com/ethereum/go-ethereum/common"
+
+// ID is the identifier for uniquely identifying a trie.
+type ID struct {
+ StateRoot common.Hash // The root of the corresponding state (block.root)
+ Owner common.Hash // The contract address hash which the trie belongs to
+ Root common.Hash // The root hash of trie
+}
+
+// StateTrieID constructs an identifier for state trie with the provided state root.
+func StateTrieID(root common.Hash) *ID {
+ return &ID{
+ StateRoot: root,
+ Owner: common.Hash{},
+ Root: root,
+ }
+}
+
+// StorageTrieID constructs an identifier for storage trie which belongs to a certain
+// state and contract specified by the stateRoot and owner.
+func StorageTrieID(stateRoot common.Hash, owner common.Hash, root common.Hash) *ID {
+ return &ID{
+ StateRoot: stateRoot,
+ Owner: owner,
+ Root: root,
+ }
+}
+
+// TrieID constructs an identifier for a standard trie (not a second-layer trie)
+// with the provided root. It's mostly used in tests and for some other tries, such as the CHT trie.
+func TrieID(root common.Hash) *ID {
+ return &ID{
+ StateRoot: root,
+ Owner: common.Hash{},
+ Root: root,
+ }
+}
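The three constructors above differ only in how StateRoot, Owner and Root are filled in. A quick illustrative comparison (the parameter names are hypothetical):

// exampleIDs contrasts the three identifier constructors. For a storage
// trie, StateRoot anchors the owning state while Root is the contract's
// storage root; for the other two, StateRoot and Root coincide.
func exampleIDs(stateRoot, ownerHash, storageRoot common.Hash) {
	_ = StateTrieID(stateRoot)                           // state trie: empty Owner, Root == StateRoot
	_ = StorageTrieID(stateRoot, ownerHash, storageRoot) // storage trie of one contract
	_ = TrieID(stateRoot)                                // standalone trie (tests, CHT)
}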
diff --git a/trie_by_cid/trie/trie_reader.go b/trie_by_cid/trie/trie_reader.go
new file mode 100644
index 0000000..b0a7fdd
--- /dev/null
+++ b/trie_by_cid/trie/trie_reader.go
@@ -0,0 +1,106 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package trie
+
+import (
+ "fmt"
+
+ "github.com/ethereum/go-ethereum/common"
+)
+
+// Reader wraps the Node and NodeBlob methods of a backing trie store.
+type Reader interface {
+ // Node retrieves the trie node with the provided trie identifier, hexary
+ // node path and the corresponding node hash.
+ // No error will be returned if the node is not found.
+ Node(owner common.Hash, path []byte, hash common.Hash) (node, error)
+
+ // NodeBlob retrieves the RLP-encoded trie node blob with the provided trie
+ // identifier, hexary node path and the corresponding node hash.
+ // No error will be returned if the node is not found.
+ NodeBlob(owner common.Hash, path []byte, hash common.Hash) ([]byte, error)
+}
+
+// NodeReader wraps all the necessary functions for accessing trie node.
+type NodeReader interface {
+ // GetReader returns a reader for accessing all trie nodes with provided
+ // state root. Nil is returned in case the state is not available.
+ GetReader(root common.Hash, codec uint64) Reader
+}
+
+// trieReader is a wrapper of the underlying node reader. It's not safe
+// for concurrent usage.
+type trieReader struct {
+ owner common.Hash
+ reader Reader
+ banned map[string]struct{} // Marker to prevent node from being accessed, for tests
+}
+
+// newTrieReader initializes the trie reader with the given node reader.
+func newTrieReader(stateRoot, owner common.Hash, db NodeReader, codec uint64) (*trieReader, error) {
+ reader := db.GetReader(stateRoot, codec)
+ if reader == nil {
+ return nil, fmt.Errorf("state not found #%x", stateRoot)
+ }
+ return &trieReader{owner: owner, reader: reader}, nil
+}
+
+// newEmptyReader initializes a pure in-memory reader. All read operations
+// are forbidden and return a MissingNodeError.
+func newEmptyReader() *trieReader {
+ return &trieReader{}
+}
+
+// node retrieves the trie node with the provided trie node information.
+// A MissingNodeError will be returned if the node is not found or
+// any error is encountered.
+func (r *trieReader) node(path []byte, hash common.Hash) (node, error) {
+ // Apply the test-only logic that prevents trie node access.
+ if r.banned != nil {
+ if _, ok := r.banned[string(path)]; ok {
+ return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path}
+ }
+ }
+ if r.reader == nil {
+ return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path}
+ }
+ node, err := r.reader.Node(r.owner, path, hash)
+ if err != nil || node == nil {
+ return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path, err: err}
+ }
+ return node, nil
+}
+
+// nodeBlob retrieves the rlp-encoded trie node with the provided trie node
+// information. A MissingNodeError will be returned if the node is
+// not found or any error is encountered.
+func (r *trieReader) nodeBlob(path []byte, hash common.Hash) ([]byte, error) {
+ // Apply the test-only logic that prevents trie node access.
+ if r.banned != nil {
+ if _, ok := r.banned[string(path)]; ok {
+ return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path}
+ }
+ }
+ if r.reader == nil {
+ return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path}
+ }
+ blob, err := r.reader.NodeBlob(r.owner, path, hash)
+ if err != nil || len(blob) == 0 {
+ return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path, err: err}
+ }
+ return blob, nil
+}
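As the comments above describe, a trieReader with no backing Reader rejects every access with a MissingNodeError. A minimal sketch of that error contract against the unexported API from this file:

// exampleEmptyReader shows that the empty reader surfaces every lookup
// as a *MissingNodeError rather than silently returning nil data.
func exampleEmptyReader() {
	r := newEmptyReader()
	_, err := r.nodeBlob([]byte{0x01}, common.Hash{})
	if _, ok := err.(*MissingNodeError); !ok {
		panic("expected a MissingNodeError")
	}
}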
diff --git a/trie_by_cid/trie/trie_test.go b/trie_by_cid/trie/trie_test.go
index 0418409..76aaf72 100644
--- a/trie_by_cid/trie/trie_test.go
+++ b/trie_by_cid/trie/trie_test.go
@@ -14,26 +14,34 @@
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-package trie_test
+package trie
import (
+ "bytes"
+ "encoding/binary"
+ "errors"
"fmt"
"math/big"
"math/rand"
+ "reflect"
"testing"
+ "testing/quick"
+ "github.com/davecgh/go-spew/spew"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/rlp"
- geth_trie "github.com/ethereum/go-ethereum/trie"
-
- "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie"
)
-func TestTrieEmpty(t *testing.T) {
- trie := trie.NewEmpty(trie.NewDatabase(rawdb.NewMemoryDatabase()))
+func init() {
+ spew.Config.Indent = " "
+ spew.Config.DisableMethods = false
+}
+
+func TestEmptyTrie(t *testing.T) {
+ trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
res := trie.Hash()
exp := types.EmptyRootHash
if res != exp {
@@ -41,42 +49,543 @@ func TestTrieEmpty(t *testing.T) {
}
}
-func TestTrieMissingRoot(t *testing.T) {
+func TestNull(t *testing.T) {
+ trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ key := make([]byte, 32)
+ value := []byte("test")
+ trie.Update(key, value)
+ if !bytes.Equal(trie.Get(key), value) {
+ t.Fatal("wrong value")
+ }
+}
+
+func TestMissingRoot(t *testing.T) {
root := common.HexToHash("0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33")
- tr, err := newStateTrie(root, trie.NewDatabase(rawdb.NewMemoryDatabase()))
- if tr != nil {
+ trie, err := NewAccountTrie(TrieID(root), NewDatabase(rawdb.NewMemoryDatabase()))
+ if trie != nil {
t.Error("New returned non-nil trie for invalid root")
}
- if _, ok := err.(*trie.MissingNodeError); !ok {
+ if _, ok := err.(*MissingNodeError); !ok {
t.Errorf("New returned wrong error: %v", err)
}
}
-func TestTrieBasic(t *testing.T) {
- edb := rawdb.NewMemoryDatabase()
- db := geth_trie.NewDatabase(edb)
- origtrie := geth_trie.NewEmpty(db)
- origtrie.Update([]byte("foo"), packValue(842))
- expected := commitTrie(t, db, origtrie)
- tr := indexTrie(t, edb, expected)
- got := tr.Hash()
- if expected != got {
- t.Errorf("got %x expected %x", got, expected)
+func TestMissingNodeMemonly(t *testing.T) { testMissingNode(t, true) }
+
+func testMissingNode(t *testing.T, memonly bool) {
+ diskdb := rawdb.NewMemoryDatabase()
+ triedb := NewDatabase(diskdb)
+
+ trie := NewEmpty(triedb)
+ updateString(trie, "120000", "qwerqwerqwerqwerqwerqwerqwerqwer")
+ updateString(trie, "123456", "asdfasdfasdfasdfasdfasdfasdfasdf")
+ root, nodes := trie.Commit(false)
+ triedb.Update(NewWithNodeSet(nodes))
+
+ trie, _ = NewAccountTrie(TrieID(root), triedb)
+ _, err := trie.TryGet([]byte("120000"))
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ }
+ trie, _ = NewAccountTrie(TrieID(root), triedb)
+ _, err = trie.TryGet([]byte("120099"))
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ }
+ trie, _ = NewAccountTrie(TrieID(root), triedb)
+ _, err = trie.TryGet([]byte("123456"))
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ }
+ trie, _ = NewAccountTrie(TrieID(root), triedb)
+ err = trie.TryUpdate([]byte("120099"), []byte("zxcvzxcvzxcvzxcvzxcvzxcvzxcvzxcv"))
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ }
+ trie, _ = NewAccountTrie(TrieID(root), triedb)
+ err = trie.TryDelete([]byte("123456"))
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ }
+
+ hash := common.HexToHash("0xe1d943cc8f061a0c0b98162830b970395ac9315654824bf21b73b891365262f9")
+ if memonly {
+ delete(triedb.dirties, hash)
+ } else {
+ diskdb.Delete(hash[:])
+ }
+
+ trie, _ = NewAccountTrie(TrieID(root), triedb)
+ _, err = trie.TryGet([]byte("120000"))
+ if _, ok := err.(*MissingNodeError); !ok {
+ t.Errorf("Wrong error: %v", err)
+ }
+ trie, _ = NewAccountTrie(TrieID(root), triedb)
+ _, err = trie.TryGet([]byte("120099"))
+ if _, ok := err.(*MissingNodeError); !ok {
+ t.Errorf("Wrong error: %v", err)
+ }
+ trie, _ = NewAccountTrie(TrieID(root), triedb)
+ _, err = trie.TryGet([]byte("123456"))
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ }
+ trie, _ = NewAccountTrie(TrieID(root), triedb)
+ err = trie.TryUpdate([]byte("120099"), []byte("zxcv"))
+ if _, ok := err.(*MissingNodeError); !ok {
+ t.Errorf("Wrong error: %v", err)
+ }
+ trie, _ = NewAccountTrie(TrieID(root), triedb)
+ err = trie.TryDelete([]byte("123456"))
+ if _, ok := err.(*MissingNodeError); !ok {
+ t.Errorf("Wrong error: %v", err)
}
- checkValue(t, tr, []byte("foo"))
}
-func TestTrieTiny(t *testing.T) {
+func TestInsert(t *testing.T) {
+ trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+
+ updateString(trie, "doe", "reindeer")
+ updateString(trie, "dog", "puppy")
+ updateString(trie, "dogglesworth", "cat")
+
+ exp := common.HexToHash("8aad789dff2f538bca5d8ea56e8abe10f4c7ba3a5dea95fea4cd6e7c3a1168d3")
+ root := trie.Hash()
+ if root != exp {
+ t.Errorf("case 1: exp %x got %x", exp, root)
+ }
+
+ trie = NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ updateString(trie, "A", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
+
+ exp = common.HexToHash("d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab")
+ root, _ = trie.Commit(false)
+ if root != exp {
+ t.Errorf("case 2: exp %x got %x", exp, root)
+ }
+}
+
+func TestGet(t *testing.T) {
+ db := NewDatabase(rawdb.NewMemoryDatabase())
+ trie := NewEmpty(db)
+ updateString(trie, "doe", "reindeer")
+ updateString(trie, "dog", "puppy")
+ updateString(trie, "dogglesworth", "cat")
+
+ for i := 0; i < 2; i++ {
+ res := getString(trie, "dog")
+ if !bytes.Equal(res, []byte("puppy")) {
+ t.Errorf("expected puppy got %x", res)
+ }
+ unknown := getString(trie, "unknown")
+ if unknown != nil {
+ t.Errorf("expected nil got %x", unknown)
+ }
+ if i == 1 {
+ return
+ }
+ root, nodes := trie.Commit(false)
+ db.Update(NewWithNodeSet(nodes))
+ trie, _ = NewAccountTrie(TrieID(root), db)
+ }
+}
+
+func TestDelete(t *testing.T) {
+ trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ vals := []struct{ k, v string }{
+ {"do", "verb"},
+ {"ether", "wookiedoo"},
+ {"horse", "stallion"},
+ {"shaman", "horse"},
+ {"doge", "coin"},
+ {"ether", ""},
+ {"dog", "puppy"},
+ {"shaman", ""},
+ }
+ for _, val := range vals {
+ if val.v != "" {
+ updateString(trie, val.k, val.v)
+ } else {
+ deleteString(trie, val.k)
+ }
+ }
+
+ hash := trie.Hash()
+ exp := common.HexToHash("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84")
+ if hash != exp {
+ t.Errorf("expected %x got %x", exp, hash)
+ }
+}
+
+func TestEmptyValues(t *testing.T) {
+ trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+
+ vals := []struct{ k, v string }{
+ {"do", "verb"},
+ {"ether", "wookiedoo"},
+ {"horse", "stallion"},
+ {"shaman", "horse"},
+ {"doge", "coin"},
+ {"ether", ""},
+ {"dog", "puppy"},
+ {"shaman", ""},
+ }
+ for _, val := range vals {
+ updateString(trie, val.k, val.v)
+ }
+
+ hash := trie.Hash()
+ exp := common.HexToHash("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84")
+ if hash != exp {
+ t.Errorf("expected %x got %x", exp, hash)
+ }
+}
+
+func TestReplication(t *testing.T) {
+ triedb := NewDatabase(rawdb.NewMemoryDatabase())
+ trie := NewEmpty(triedb)
+ vals := []struct{ k, v string }{
+ {"do", "verb"},
+ {"ether", "wookiedoo"},
+ {"horse", "stallion"},
+ {"shaman", "horse"},
+ {"doge", "coin"},
+ {"dog", "puppy"},
+ {"somethingveryoddindeedthis is", "myothernodedata"},
+ }
+ for _, val := range vals {
+ updateString(trie, val.k, val.v)
+ }
+ exp, nodes := trie.Commit(false)
+ triedb.Update(NewWithNodeSet(nodes))
+
+ // create a new trie on top of the database and check that lookups work.
+ trie2, err := NewAccountTrie(TrieID(exp), triedb)
+ if err != nil {
+ t.Fatalf("can't recreate trie at %x: %v", exp, err)
+ }
+ for _, kv := range vals {
+ if string(getString(trie2, kv.k)) != kv.v {
+ t.Errorf("trie2 doesn't have %q => %q", kv.k, kv.v)
+ }
+ }
+ hash, nodes := trie2.Commit(false)
+ if hash != exp {
+ t.Errorf("root failure. expected %x got %x", exp, hash)
+ }
+
+ // recreate the trie after commit
+ if nodes != nil {
+ triedb.Update(NewWithNodeSet(nodes))
+ }
+ trie2, err = NewAccountTrie(TrieID(hash), triedb)
+ if err != nil {
+ t.Fatalf("can't recreate trie at %x: %v", exp, err)
+ }
+ // perform some insertions on the new trie.
+ vals2 := []struct{ k, v string }{
+ {"do", "verb"},
+ {"ether", "wookiedoo"},
+ {"horse", "stallion"},
+ // {"shaman", "horse"},
+ // {"doge", "coin"},
+ // {"ether", ""},
+ // {"dog", "puppy"},
+ // {"somethingveryoddindeedthis is", "myothernodedata"},
+ // {"shaman", ""},
+ }
+ for _, val := range vals2 {
+ updateString(trie2, val.k, val.v)
+ }
+ if hash := trie2.Hash(); hash != exp {
+ t.Errorf("root failure. expected %x got %x", exp, hash)
+ }
+}
+
+func TestLargeValue(t *testing.T) {
+ trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ trie.Update([]byte("key1"), []byte{99, 99, 99, 99})
+ trie.Update([]byte("key2"), bytes.Repeat([]byte{1}, 32))
+ trie.Hash()
+}
+
+// TestRandomCases tests some cases that were found via random fuzzing
+func TestRandomCases(t *testing.T) {
+ var rt = []randTestStep{
+ {op: 6, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 0
+ {op: 6, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 1
+ {op: 0, key: common.Hex2Bytes("d51b182b95d677e5f1c82508c0228de96b73092d78ce78b2230cd948674f66fd1483bd"), value: common.Hex2Bytes("0000000000000002")}, // step 2
+ {op: 2, key: common.Hex2Bytes("c2a38512b83107d665c65235b0250002882ac2022eb00711552354832c5f1d030d0e408e"), value: common.Hex2Bytes("")}, // step 3
+ {op: 3, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 4
+ {op: 3, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 5
+ {op: 6, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 6
+ {op: 3, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 7
+ {op: 0, key: common.Hex2Bytes("c2a38512b83107d665c65235b0250002882ac2022eb00711552354832c5f1d030d0e408e"), value: common.Hex2Bytes("0000000000000008")}, // step 8
+ {op: 0, key: common.Hex2Bytes("d51b182b95d677e5f1c82508c0228de96b73092d78ce78b2230cd948674f66fd1483bd"), value: common.Hex2Bytes("0000000000000009")}, // step 9
+ {op: 2, key: common.Hex2Bytes("fd"), value: common.Hex2Bytes("")}, // step 10
+ {op: 6, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 11
+ {op: 6, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 12
+ {op: 0, key: common.Hex2Bytes("fd"), value: common.Hex2Bytes("000000000000000d")}, // step 13
+ {op: 6, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 14
+ {op: 1, key: common.Hex2Bytes("c2a38512b83107d665c65235b0250002882ac2022eb00711552354832c5f1d030d0e408e"), value: common.Hex2Bytes("")}, // step 15
+ {op: 3, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 16
+ {op: 0, key: common.Hex2Bytes("c2a38512b83107d665c65235b0250002882ac2022eb00711552354832c5f1d030d0e408e"), value: common.Hex2Bytes("0000000000000011")}, // step 17
+ {op: 5, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 18
+ {op: 3, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 19
+ {op: 0, key: common.Hex2Bytes("d51b182b95d677e5f1c82508c0228de96b73092d78ce78b2230cd948674f66fd1483bd"), value: common.Hex2Bytes("0000000000000014")}, // step 20
+ {op: 0, key: common.Hex2Bytes("d51b182b95d677e5f1c82508c0228de96b73092d78ce78b2230cd948674f66fd1483bd"), value: common.Hex2Bytes("0000000000000015")}, // step 21
+ {op: 0, key: common.Hex2Bytes("c2a38512b83107d665c65235b0250002882ac2022eb00711552354832c5f1d030d0e408e"), value: common.Hex2Bytes("0000000000000016")}, // step 22
+ {op: 5, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 23
+ {op: 1, key: common.Hex2Bytes("980c393656413a15c8da01978ed9f89feb80b502f58f2d640e3a2f5f7a99a7018f1b573befd92053ac6f78fca4a87268"), value: common.Hex2Bytes("")}, // step 24
+ {op: 1, key: common.Hex2Bytes("fd"), value: common.Hex2Bytes("")}, // step 25
+ }
+ runRandTest(rt)
+}
+
+// randTest performs random trie operations.
+// Instances of this test are created by Generate.
+type randTest []randTestStep
+
+type randTestStep struct {
+ op int
+ key []byte // for opUpdate, opDelete, opGet
+ value []byte // for opUpdate
+ err error // for debugging
+}
+
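+// Operation codes for randTestStep.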
+const (
+ opUpdate = iota
+ opDelete
+ opGet
+ opHash
+ opCommit
+ opItercheckhash
+ opNodeDiff
+ opProve
+ opMax // boundary value, not an actual op
+)
+
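+// Generate implements testing/quick.Generator, building random operation
+// sequences that reuse a small pool of keys.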
+func (randTest) Generate(r *rand.Rand, size int) reflect.Value {
+ var allKeys [][]byte
+ genKey := func() []byte {
+ if len(allKeys) < 2 || r.Intn(100) < 10 {
+ // new key
+ key := make([]byte, r.Intn(50))
+ r.Read(key)
+ allKeys = append(allKeys, key)
+ return key
+ }
+ // use existing key
+ return allKeys[r.Intn(len(allKeys))]
+ }
+
+ var steps randTest
+ for i := 0; i < size; i++ {
+ step := randTestStep{op: r.Intn(opMax)}
+ switch step.op {
+ case opUpdate:
+ step.key = genKey()
+ step.value = make([]byte, 8)
+ binary.BigEndian.PutUint64(step.value, uint64(i))
+ case opGet, opDelete, opProve:
+ step.key = genKey()
+ }
+ steps = append(steps, step)
+ }
+ return reflect.ValueOf(steps)
+}
+
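+// verifyAccessList checks a committed NodeSet against the actual diff of the
+// two tries: inserts must be new nodes without an origin value, while deletes
+// and updates must record the previous node blob in the access list.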
+func verifyAccessList(old *Trie, new *Trie, set *NodeSet) error {
+ deletes, inserts, updates := diffTries(old, new)
+
+ // Check insertion set
+ for path := range inserts {
+ n, ok := set.nodes[path]
+ if !ok || n.isDeleted() {
+ return errors.New("expect new node")
+ }
+ _, ok = set.accessList[path]
+ if ok {
+ return errors.New("unexpected origin value")
+ }
+ }
+ // Check deletion set
+ for path, blob := range deletes {
+ n, ok := set.nodes[path]
+ if !ok || !n.isDeleted() {
+ return errors.New("expect deleted node")
+ }
+ v, ok := set.accessList[path]
+ if !ok {
+ return errors.New("expect origin value")
+ }
+ if !bytes.Equal(v, blob) {
+ return errors.New("invalid origin value")
+ }
+ }
+ // Check update set
+ for path, blob := range updates {
+ n, ok := set.nodes[path]
+ if !ok || n.isDeleted() {
+ return errors.New("expect updated node")
+ }
+ v, ok := set.accessList[path]
+ if !ok {
+ return errors.New("expect origin value")
+ }
+ if !bytes.Equal(v, blob) {
+ return errors.New("invalid origin value")
+ }
+ }
+ return nil
+}
+
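+// runRandTest replays a random operation sequence against a trie, mirroring
+// all writes into a reference map and cross-checking lookups, proofs, commits,
+// iterators and the tracer's insert/delete sets along the way.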
+func runRandTest(rt randTest) bool {
+ var (
+ triedb = NewDatabase(rawdb.NewMemoryDatabase())
+ tr = NewEmpty(triedb)
+ values = make(map[string]string) // tracks content of the trie
+ origTrie = NewEmpty(triedb)
+ )
+ for i, step := range rt {
+ // fmt.Printf("{op: %d, key: common.Hex2Bytes(\"%x\"), value: common.Hex2Bytes(\"%x\")}, // step %d\n",
+ // step.op, step.key, step.value, i)
+
+ switch step.op {
+ case opUpdate:
+ tr.Update(step.key, step.value)
+ values[string(step.key)] = string(step.value)
+ case opDelete:
+ tr.Delete(step.key)
+ delete(values, string(step.key))
+ case opGet:
+ v := tr.Get(step.key)
+ want := values[string(step.key)]
+ if string(v) != want {
+ rt[i].err = fmt.Errorf("mismatch for key %#x, got %#x want %#x", step.key, v, want)
+ }
+ case opProve:
+ hash := tr.Hash()
+ if hash == types.EmptyRootHash {
+ continue
+ }
+ proofDb := rawdb.NewMemoryDatabase()
+ err := tr.Prove(step.key, 0, proofDb)
+ if err != nil {
+ rt[i].err = fmt.Errorf("failed for proving key %#x, %v", step.key, err)
+ }
+ _, err = VerifyProof(hash, step.key, proofDb)
+ if err != nil {
+ rt[i].err = fmt.Errorf("failed for verifying key %#x, %v", step.key, err)
+ }
+ case opHash:
+ tr.Hash()
+ case opCommit:
+ root, nodes := tr.Commit(true)
+ if nodes != nil {
+ triedb.Update(NewWithNodeSet(nodes))
+ }
+ newtr, err := NewAccountTrie(TrieID(root), triedb)
+ if err != nil {
+ rt[i].err = err
+ return false
+ }
+ if nodes != nil {
+ if err := verifyAccessList(origTrie, newtr, nodes); err != nil {
+ rt[i].err = err
+ return false
+ }
+ }
+ tr = newtr
+ origTrie = tr.Copy()
+ case opItercheckhash:
+ checktr := NewEmpty(triedb)
+ it := NewIterator(tr.NodeIterator(nil))
+ for it.Next() {
+ checktr.Update(it.Key, it.Value)
+ }
+ if tr.Hash() != checktr.Hash() {
+ rt[i].err = fmt.Errorf("hash mismatch in opItercheckhash")
+ }
+ case opNodeDiff:
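+ // Recompute the expected insert/delete sets by walking both tries,
+ // then compare them with what the tracer recorded.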
+ var (
+ origIter = origTrie.NodeIterator(nil)
+ curIter = tr.NodeIterator(nil)
+ origSeen = make(map[string]struct{})
+ curSeen = make(map[string]struct{})
+ )
+ for origIter.Next(true) {
+ if origIter.Leaf() {
+ continue
+ }
+ origSeen[string(origIter.Path())] = struct{}{}
+ }
+ for curIter.Next(true) {
+ if curIter.Leaf() {
+ continue
+ }
+ curSeen[string(curIter.Path())] = struct{}{}
+ }
+ var (
+ insertExp = make(map[string]struct{})
+ deleteExp = make(map[string]struct{})
+ )
+ for path := range curSeen {
+ _, present := origSeen[path]
+ if !present {
+ insertExp[path] = struct{}{}
+ }
+ }
+ for path := range origSeen {
+ _, present := curSeen[path]
+ if !present {
+ deleteExp[path] = struct{}{}
+ }
+ }
+ if len(insertExp) != len(tr.tracer.inserts) {
+ rt[i].err = fmt.Errorf("insert set mismatch")
+ }
+ if len(deleteExp) != len(tr.tracer.deletes) {
+ rt[i].err = fmt.Errorf("delete set mismatch")
+ }
+ for insert := range tr.tracer.inserts {
+ if _, present := insertExp[insert]; !present {
+ rt[i].err = fmt.Errorf("missing inserted node")
+ }
+ }
+ for del := range tr.tracer.deletes {
+ if _, present := deleteExp[del]; !present {
+ rt[i].err = fmt.Errorf("missing deleted node")
+ }
+ }
+ }
+ // Abort the test on error.
+ if rt[i].err != nil {
+ return false
+ }
+ }
+ return true
+}
+
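+// TestRandom runs runRandTest under testing/quick, printing the generated
+// operation sequence with spew when an iteration fails.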
+func TestRandom(t *testing.T) {
+ if err := quick.Check(runRandTest, nil); err != nil {
+ if cerr, ok := err.(*quick.CheckError); ok {
+ t.Fatalf("random test iteration %d failed: %s", cerr.Count, spew.Sdump(cerr.In))
+ }
+ t.Fatal(err)
+ }
+}
+
+func TestTinyTrie(t *testing.T) {
// Create a realistic account trie to hash
_, accounts := makeAccounts(5)
- edb := rawdb.NewMemoryDatabase()
- db := geth_trie.NewDatabase(edb)
- origtrie := geth_trie.NewEmpty(db)
+ trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
type testCase struct {
key, account []byte
root common.Hash
}
+
cases := []testCase{
{
common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001337"),
@@ -92,48 +601,48 @@ func TestTrieTiny(t *testing.T) {
common.HexToHash("0608c1d1dc3905fa22204c7a0e43644831c3b6d3def0f274be623a948197e64a"),
},
}
- for i, tc := range cases {
- t.Run(fmt.Sprintf("case %d", i), func(t *testing.T) {
- origtrie.Update(tc.key, tc.account)
- trie := indexTrie(t, edb, commitTrie(t, db, origtrie))
- if exp, root := tc.root, trie.Hash(); exp != root {
- t.Errorf("got %x, exp %x", root, exp)
- }
- checkValue(t, trie, tc.key)
- })
+ for i, c := range cases {
+ trie.Update(c.key, c.account)
+ root := trie.Hash()
+ if root != c.root {
+ t.Errorf("case %d: got %x, exp %x", i, root, c.root)
+ }
+ }
+ checktr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
+ it := NewIterator(trie.NodeIterator(nil))
+ for it.Next() {
+ checktr.Update(it.Key, it.Value)
+ }
+ if troot, itroot := trie.Hash(), checktr.Hash(); troot != itroot {
+ t.Fatalf("hash mismatch in opItercheckhash, trie: %x, check: %x", troot, itroot)
}
}
-func TestTrieMedium(t *testing.T) {
+func TestCommitAfterHash(t *testing.T) {
// Create a realistic account trie to hash
addresses, accounts := makeAccounts(1000)
- edb := rawdb.NewMemoryDatabase()
- db := geth_trie.NewDatabase(edb)
- origtrie := geth_trie.NewEmpty(db)
- var keys [][]byte
+ trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
for i := 0; i < len(addresses); i++ {
- key := crypto.Keccak256(addresses[i][:])
- if i%50 == 0 {
- keys = append(keys, key)
- }
- origtrie.Update(key, accounts[i])
+ trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i])
}
- tr := indexTrie(t, edb, commitTrie(t, db, origtrie))
-
- root := tr.Hash()
+ // Insert the accounts into the trie and hash it
+ trie.Hash()
+ trie.Commit(false)
+ root := trie.Hash()
exp := common.HexToHash("72f9d3f3fe1e1dd7b8936442e7642aef76371472d94319900790053c493f3fe6")
if exp != root {
t.Errorf("got %x, exp %x", root, exp)
}
-
- for _, key := range keys {
- checkValue(t, tr, key)
+ root, _ = trie.Commit(false)
+ if exp != root {
+ t.Errorf("got %x, exp %x", root, exp)
}
}
-// Make deterministically random accounts
func makeAccounts(size int) (addresses [][20]byte, accounts [][]byte) {
+ // Make the random benchmark deterministic
random := rand.New(rand.NewSource(0))
+ // Create a realistic account trie to hash
addresses = make([][20]byte, size)
for i := 0; i < len(addresses); i++ {
data := make([]byte, 20)
@@ -149,25 +658,40 @@ func makeAccounts(size int) (addresses [][20]byte, accounts [][]byte) {
)
// The big.Rand function is not deterministic across 64- and 32-bit systems,
// and will consume a different amount of data from the rand source.
- // balance = new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil))
+ //balance = new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil))
// Therefore, we instead just read via byte buffer
numBytes := random.Uint32() % 33 // [0, 32] bytes
balanceBytes := make([]byte, numBytes)
random.Read(balanceBytes)
balance := new(big.Int).SetBytes(balanceBytes)
- acct := &types.StateAccount{Nonce: nonce, Balance: balance, Root: root, CodeHash: code}
- data, _ := rlp.EncodeToBytes(acct)
+ data, _ := rlp.EncodeToBytes(&types.StateAccount{Nonce: nonce, Balance: balance, Root: root, CodeHash: code})
accounts[i] = data
}
return addresses, accounts
}
-func checkValue(t *testing.T, tr *trie.Trie, key []byte) {
- val, err := tr.TryGet(key)
- if err != nil {
- t.Fatalf("error getting node: %s", err)
- }
- if len(val) == 0 {
- t.Errorf("failed to get value for %x", key)
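+// String-keyed convenience wrappers around the trie's byte-slice API.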
+func getString(trie *Trie, k string) []byte {
+ return trie.Get([]byte(k))
+}
+
+func updateString(trie *Trie, k, v string) {
+ trie.Update([]byte(k), []byte(v))
+}
+
+func deleteString(trie *Trie, k string) {
+ trie.Delete([]byte(k))
+}
+
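+// TestDecodeNode feeds decodeNode pseudo-random hashes and bodies to make
+// sure it never panics on malformed input.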
+func TestDecodeNode(t *testing.T) {
+ t.Parallel()
+
+ var (
+ hash = make([]byte, 20)
+ elems = make([]byte, 20)
+ )
+ for i := 0; i < 5000000; i++ {
+ prng.Read(hash)
+ prng.Read(elems)
+ decodeNode(hash, elems)
}
}
diff --git a/trie_by_cid/trie/util_test.go b/trie_by_cid/trie/util_test.go
index 3272826..a75ac8c 100644
--- a/trie_by_cid/trie/util_test.go
+++ b/trie_by_cid/trie/util_test.go
@@ -1,43 +1,133 @@
-package trie_test
+package trie
import (
+ "bytes"
+ "context"
"fmt"
"math/big"
"math/rand"
"testing"
- "time"
-
- "github.com/jmoiron/sqlx"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
- geth_state "github.com/ethereum/go-ethereum/core/state"
+ gethstate "github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/rlp"
- geth_trie "github.com/ethereum/go-ethereum/trie"
+ gethtrie "github.com/ethereum/go-ethereum/trie"
+ "github.com/jmoiron/sqlx"
pgipfsethdb "github.com/cerc-io/ipfs-ethdb/v5/postgres/v0"
- "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/helper"
- "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/state"
- "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie"
"github.com/ethereum/go-ethereum/statediff/indexer/database/sql/postgres"
- "github.com/ethereum/go-ethereum/statediff/indexer/ipld"
"github.com/ethereum/go-ethereum/statediff/test_helpers"
+
+ "github.com/cerc-io/ipld-eth-statedb/internal"
+ "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/helper"
)
-type kv struct {
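+// Shared test fixtures: dbConfig is read from the environment and trieConfig
+// sets the trie database's cache size.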
+var (
+ dbConfig, _ = postgres.DefaultConfig.WithEnv()
+ trieConfig = Config{Cache: 256}
+)
+
+type kvi struct {
k []byte
v int64
}
-type kvMap map[string]*kv
+type kvMap map[string]*kvi
-type kvs struct {
+type kvsi struct {
k string
v int64
}
+// NewAccountTrie is a shortcut to create a trie using the StateTrieCodec (i.e. the IPLD MEthStateTrie codec).
+func NewAccountTrie(id *ID, db NodeReader) (*Trie, error) {
+ return New(id, db, StateTrieCodec)
+}
+
+// makeTestTrie creates a sample test trie to test node-wise reconstruction.
+func makeTestTrie(t testing.TB) (*Database, *StateTrie, map[string][]byte) {
+ // Create an empty trie
+ triedb := NewDatabase(rawdb.NewMemoryDatabase())
+ trie, err := NewStateTrie(TrieID(common.Hash{}), triedb, StateTrieCodec)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Fill it with some arbitrary data
+ content := make(map[string][]byte)
+ for i := byte(0); i < 255; i++ {
+ // Map the same data under multiple keys
+ key, val := common.LeftPadBytes([]byte{1, i}, 32), []byte{i}
+ content[string(key)] = val
+ trie.Update(key, val)
+
+ key, val = common.LeftPadBytes([]byte{2, i}, 32), []byte{i}
+ content[string(key)] = val
+ trie.Update(key, val)
+
+ // Add some other data to inflate the trie
+ for j := byte(3); j < 13; j++ {
+ key, val = common.LeftPadBytes([]byte{j, i}, 32), []byte{j, i}
+ content[string(key)] = val
+ trie.Update(key, val)
+ }
+ }
+ root, nodes := trie.Commit(false)
+ if err := triedb.Update(NewWithNodeSet(nodes)); err != nil {
+ panic(fmt.Errorf("failed to commit db %v", err))
+ }
+ // Re-create the trie based on the new state
+ trie, err = NewStateTrie(TrieID(root), triedb, StateTrieCodec)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return triedb, trie, content
+}
+
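+// forHashedNodes returns the blobs of all hashed (non-embedded) nodes in the
+// trie, keyed by node path.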
+func forHashedNodes(tr *Trie) map[string][]byte {
+ var (
+ it = tr.NodeIterator(nil)
+ nodes = make(map[string][]byte)
+ )
+ for it.Next(true) {
+ if it.Hash() == (common.Hash{}) {
+ continue
+ }
+ nodes[string(it.Path())] = common.CopyBytes(it.NodeBlob())
+ }
+ return nodes
+}
+
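+// diffTries compares the hashed nodes of two tries, returning the nodes only
+// in trieA, the nodes only in trieB, and the nodes present in both with
+// differing values.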
+func diffTries(trieA, trieB *Trie) (map[string][]byte, map[string][]byte, map[string][]byte) {
+ var (
+ nodesA = forHashedNodes(trieA)
+ nodesB = forHashedNodes(trieB)
+ inA = make(map[string][]byte) // hashed nodes in trie a but not b
+ inB = make(map[string][]byte) // hashed nodes in trie b but not a
+ both = make(map[string][]byte) // hashed nodes in both tries but different value
+ )
+ for path, blobA := range nodesA {
+ if blobB, ok := nodesB[path]; ok {
+ if bytes.Equal(blobA, blobB) {
+ continue
+ }
+ both[path] = blobA
+ continue
+ }
+ inA[path] = blobA
+ }
+ for path, blobB := range nodesB {
+ if _, ok := nodesA[path]; ok {
+ continue
+ }
+ inB[path] = blobB
+ }
+ return inA, inB, both
+}
+
func packValue(val int64) []byte {
acct := &types.StateAccount{
Balance: big.NewInt(val),
@@ -51,27 +141,19 @@ func packValue(val int64) []byte {
return acct_rlp
}
-func unpackValue(val []byte) int64 {
- var acct types.StateAccount
- if err := rlp.DecodeBytes(val, &acct); err != nil {
- panic(err)
- }
- return acct.Balance.Int64()
-}
-
-func updateTrie(tr *geth_trie.Trie, vals []kvs) (kvMap, error) {
+func updateTrie(tr *gethtrie.Trie, vals []kvsi) (kvMap, error) {
all := kvMap{}
for _, val := range vals {
- all[string(val.k)] = &kv{[]byte(val.k), val.v}
+ all[string(val.k)] = &kvi{[]byte(val.k), val.v}
tr.Update([]byte(val.k), packValue(val.v))
}
return all, nil
}
-func commitTrie(t testing.TB, db *geth_trie.Database, tr *geth_trie.Trie) common.Hash {
+func commitTrie(t testing.TB, db *gethtrie.Database, tr *gethtrie.Trie) common.Hash {
t.Helper()
root, nodes := tr.Commit(false)
- if err := db.Update(geth_trie.NewWithNodeSet(nodes)); err != nil {
+ if err := db.Update(gethtrie.NewWithNodeSet(nodes)); err != nil {
t.Fatal(err)
}
if err := db.Commit(root, false); err != nil {
@@ -80,16 +162,8 @@ func commitTrie(t testing.TB, db *geth_trie.Database, tr *geth_trie.Trie) common
return root
}
-// commit a LevelDB state trie, index to IPLD and return new trie
-func indexTrie(t testing.TB, edb ethdb.Database, root common.Hash) *trie.Trie {
- t.Helper()
- dbConfig.Driver = postgres.PGX
- err := helper.IndexChain(dbConfig, geth_state.NewDatabase(edb), common.Hash{}, root)
- if err != nil {
- t.Fatal(err)
- }
-
- pg_db, err := postgres.ConnectSQLX(ctx, dbConfig)
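+// makePgIpfsEthDB connects to the Postgres test database and wraps the
+// connection in a pgipfsethdb.Database with a per-test cache, so tries can be
+// read through the IPLD block store.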
+func makePgIpfsEthDB(t testing.TB) ethdb.Database {
+ pg_db, err := postgres.ConnectSQLX(context.Background(), dbConfig)
if err != nil {
t.Fatal(err)
}
@@ -98,62 +172,48 @@ func indexTrie(t testing.TB, edb ethdb.Database, root common.Hash) *trie.Trie {
t.Fatal(err)
}
})
+ return pgipfsethdb.NewDatabase(pg_db, internal.MakeCacheConfig(t))
+}
- ipfs_db := pgipfsethdb.NewDatabase(pg_db, makeCacheConfig(t))
- sdb_db := state.NewDatabase(ipfs_db)
- tr, err := newStateTrie(root, sdb_db.TrieDB())
+// indexTrie commits a LevelDB state trie, indexes it into IPLD, and returns a new Trie reading from the IPLD store.
+func indexTrie(t testing.TB, edb ethdb.Database, root common.Hash) *Trie {
+ t.Helper()
+ dbConfig.Driver = postgres.PGX
+ err := helper.IndexStateDiff(dbConfig, gethstate.NewDatabase(edb), common.Hash{}, root)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ ipfs_db := makePgIpfsEthDB(t)
+ tr, err := New(TrieID(root), NewDatabase(ipfs_db), StateTrieCodec)
if err != nil {
t.Fatal(err)
}
return tr
}
-func newStateTrie(root common.Hash, db *trie.Database) (*trie.Trie, error) {
- tr, err := trie.New(common.Hash{}, root, db, ipld.MEthStateTrie)
- if err != nil {
- return nil, err
- }
- return tr, nil
-}
-
// generates a random Geth LevelDB trie of n key-value pairs and corresponding value map
-func randomGethTrie(n int, db *geth_trie.Database) (*geth_trie.Trie, kvMap) {
- trie := geth_trie.NewEmpty(db)
- var vals []*kv
+func randomGethTrie(n int, db *gethtrie.Database) (*gethtrie.Trie, kvMap) {
+ trie := gethtrie.NewEmpty(db)
+ var vals []*kvi
for i := byte(0); i < 100; i++ {
- e := &kv{common.LeftPadBytes([]byte{i}, 32), int64(i)}
- e2 := &kv{common.LeftPadBytes([]byte{i + 10}, 32), int64(i)}
+ e := &kvi{common.LeftPadBytes([]byte{i}, 32), int64(i)}
+ e2 := &kvi{common.LeftPadBytes([]byte{i + 10}, 32), int64(i)}
vals = append(vals, e, e2)
}
for i := 0; i < n; i++ {
k := randBytes(32)
v := rand.Int63()
- vals = append(vals, &kv{k, v})
+ vals = append(vals, &kvi{k, v})
}
all := kvMap{}
for _, val := range vals {
- all[string(val.k)] = &kv{[]byte(val.k), val.v}
+ all[string(val.k)] = &kvi{[]byte(val.k), val.v}
trie.Update([]byte(val.k), packValue(val.v))
}
return trie, all
}
-// generates a random IPLD-indexed trie
-func randomTrie(t testing.TB, n int) (*trie.Trie, kvMap) {
- edb := rawdb.NewMemoryDatabase()
- db := geth_trie.NewDatabase(edb)
- orig, vals := randomGethTrie(n, db)
- root := commitTrie(t, db, orig)
- trie := indexTrie(t, edb, root)
- return trie, vals
-}
-
-func randBytes(n int) []byte {
- r := make([]byte, n)
- rand.Read(r)
- return r
-}
-
// TearDownDB is used to tear down the watcher dbs after tests
func TearDownDB(db *sqlx.DB) error {
tx, err := db.Beginx()
@@ -179,12 +239,3 @@ func TearDownDB(db *sqlx.DB) error {
}
return tx.Commit()
}
-
-// returns a cache config with unique name (groupcache names are global)
-func makeCacheConfig(t testing.TB) pgipfsethdb.CacheConfig {
- return pgipfsethdb.CacheConfig{
- Name: t.Name(),
- Size: 3000000, // 3MB
- ExpiryDuration: time.Hour,
- }
-}