diff --git a/.github/workflows/on-pr.yml b/.github/workflows/on-pr.yml new file mode 100644 index 0000000..44a77d2 --- /dev/null +++ b/.github/workflows/on-pr.yml @@ -0,0 +1,7 @@ +name: PR actions +on: + - pull_request + +jobs: + run-tests: + uses: ./.github/workflows/test.yml diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..7380088 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,26 @@ +name: Run tests +on: + workflow_call: + +jobs: + unit-tests: + name: Run unit tests + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-go@v4 + with: + go-version-file: 'go.mod' + check-latest: true + - name: "Run DB container" + working-directory: ./test + run: | + docker compose up -d --quiet-pull + - name: "Build and run tests" + run: | + until [[ "$(docker inspect test-ipld-eth-db | jq -r '.[0].State.Status')" = 'running' ]] + do sleep 1; done & + + go build ./... + wait $! + go test -v ./... diff --git a/README.md b/README.md index 6f331b2..346b70b 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,14 @@ # ipld-eth-statedb -Implementation of the geth [vm.StateDB](https://github.com/ethereum/go-ethereum/blob/master/core/vm/interface.go#L28) on top of -[ipld-eth-db](https://github.com/cerc-io/ipld-eth-db), to allow us to plug into existing EVM functionality. Analogous to -[ipfs-ethdb](https://github.com/cerc-io/ipfs-ethdb) but at one database abstraction level higher. This allows us to -bypass the trie-traversal access pattern normally used by the EVM (and which ipfs-ethdb allows us to replicate ontop of our -Postgres IPLD blockstore in the "public.blocks" table) and access state and storage directly in the "state_cids" and -"storage_cids" tables. +This contains multiple implementations of the geth [vm.StateDB](https://github.com/ethereum/go-ethereum/blob/master/core/vm/interface.go#L28) and supporting types for different use cases. 
+## Package `direct_by_leaf` -Note: "IPFS" is chosen in the name of "ipfs-ethdb" as it can function through an IPFS BlockService abstraction or directly ontop of an IPLD blockstore, whereas this repository -is very tightly coupled to the schema in ipld-eth-db. +A read-only implementation which uses the schema defined in [ipld-eth-db](https://github.com/cerc-io/ipld-eth-db), to allow direct querying by state and storage node leaf key, bypassing the trie-traversal access pattern normally used by the EVM. +This operates at one abstraction level higher than [ipfs-ethdb](https://github.com/cerc-io/ipfs-ethdb), and is suitable for providing fast state reads. -The top-level package contains the implementation of the `vm.StateDB` interface that accesses state directly using the -`state_cids` and `storage_cids` tables in ipld-eth-db. The `trie_by_cid` package contains an alternative implementation -which accesses state in `ipld.blocks` through the typical trie traversal access pattern (using CIDs instead of raw -keccak256 hashes), it is used for benchmarking and for functionality which requires performing a trie traversal -(things which must collect intermediate nodes, e.g. `eth_getProof` and `eth_getSlice`). +## Package `trie_by_cid` + +A read-write implementation which uses a Postgres IPLD v0 Blockstore as the backing `ethdb.Database`. Specifically this passes v1 CIDs of Keccak-256 hashes to the database in place of plain hashes, and can be used in combination with a [ipfs-ethdb/postgres/v0](https://github.com/cerc-io/ipfs-ethdb/tree/v5/postgres/v0) `Database` instance, or an IPLD BlockService providing a v0 Blockstore. + +This implementation uses trie traversal to access state, and is capable of computing state root hashes and performing full EVM operations. It's also suitable for scenarios requiring trie traversal and access to intermediate state nodes (e.g. `eth_getProof` and `eth_getSlice` on [ipld-eth-server](https://github.com/cerc-io/ipld-eth-server)). 
diff --git a/direct_by_leaf/statedb_test.go b/direct_by_leaf/statedb_test.go index 03cd91b..bc945a7 100644 --- a/direct_by_leaf/statedb_test.go +++ b/direct_by_leaf/statedb_test.go @@ -22,7 +22,13 @@ import ( ) var ( - testCtx = context.Background() + testCtx = context.Background() + teardownStatements = []string{ + `TRUNCATE eth.header_cids`, + `TRUNCATE eth.state_cids`, + `TRUNCATE eth.storage_cids`, + `TRUNCATE ipld.blocks`, + } // Fixture data // block one: contract account and slot are created @@ -96,13 +102,7 @@ func TestPGXSuite(t *testing.T) { t.Cleanup(func() { tx, err := pool.Begin(testCtx) require.NoError(t, err) - statements := []string{ - `TRUNCATE eth.header_cids`, - `TRUNCATE eth.state_cids`, - `TRUNCATE eth.storage_cids`, - `TRUNCATE ipld.blocks`, - } - for _, stm := range statements { + for _, stm := range teardownStatements { _, err = tx.Exec(testCtx, stm) require.NoErrorf(t, err, "Exec(`%s`)", stm) } @@ -127,13 +127,7 @@ func TestSQLXSuite(t *testing.T) { t.Cleanup(func() { tx, err := pool.Begin() require.NoError(t, err) - statements := []string{ - `TRUNCATE eth.header_cids`, - `TRUNCATE eth.state_cids`, - `TRUNCATE eth.storage_cids`, - `TRUNCATE ipld.blocks`, - } - for _, stm := range statements { + for _, stm := range teardownStatements { _, err = tx.Exec(stm) require.NoErrorf(t, err, "Exec(`%s`)", stm) } diff --git a/go.mod b/go.mod index 6c1f715..8087592 100644 --- a/go.mod +++ b/go.mod @@ -1,10 +1,11 @@ module github.com/cerc-io/ipld-eth-statedb -go 1.18 +go 1.19 require ( github.com/VictoriaMetrics/fastcache v1.6.0 github.com/cerc-io/ipfs-ethdb/v5 v5.0.0-alpha + github.com/davecgh/go-spew v1.1.1 github.com/ethereum/go-ethereum v1.11.5 github.com/hashicorp/golang-lru v0.5.5-0.20210104140557-80c98217689d github.com/ipfs/go-cid v0.2.0 @@ -13,6 +14,7 @@ require ( github.com/jmoiron/sqlx v1.3.5 github.com/lib/pq v1.10.6 github.com/multiformats/go-multihash v0.1.0 + github.com/sirupsen/logrus v1.9.0 github.com/stretchr/testify v1.8.1 
golang.org/x/crypto v0.6.0 ) @@ -21,13 +23,13 @@ require ( github.com/DataDog/zstd v1.5.2 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/btcsuite/btcd/btcec/v2 v2.2.0 // indirect + github.com/btcsuite/btcd/chaincfg/chainhash v1.0.2 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/cockroachdb/errors v1.9.1 // indirect github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b // indirect github.com/cockroachdb/pebble v0.0.0-20230209160836-829675f94811 // indirect github.com/cockroachdb/redact v1.1.3 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect github.com/deckarep/golang-set/v2 v2.1.0 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 // indirect github.com/edsrzf/mmap-go v1.0.0 // indirect @@ -89,7 +91,6 @@ require ( github.com/segmentio/fasthash v1.0.3 // indirect github.com/shirou/gopsutil v3.21.11+incompatible // indirect github.com/shopspring/decimal v1.2.0 // indirect - github.com/sirupsen/logrus v1.9.0 // indirect github.com/spaolacci/murmur3 v1.1.0 // indirect github.com/status-im/keycard-go v0.2.0 // indirect github.com/stretchr/objx v0.5.0 // indirect @@ -112,4 +113,4 @@ require ( lukechampine.com/blake3 v1.1.6 // indirect ) -replace github.com/ethereum/go-ethereum v1.11.5 => github.com/cerc-io/go-ethereum v1.11.5-statediff-5.0.1-alpha +replace github.com/ethereum/go-ethereum v1.11.5 => github.com/cerc-io/go-ethereum v1.11.5-statediff-5.0.3-alpha diff --git a/go.sum b/go.sum index fef91fa..e7e8faa 100644 --- a/go.sum +++ b/go.sum @@ -17,12 +17,13 @@ github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5 github.com/aymerick/raymond v2.0.3-0.20180322193309-b565731e1464+incompatible/go.mod h1:osfaiScAUVup+UC9Nfq76eWqDhXlp+4UYaA8uhTBO6g= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= 
-github.com/btcsuite/btcd v0.22.0-beta h1:LTDpDKUM5EeOFBPM8IXpinEcmZ6FWfNZbE3lfrfdnWo= github.com/btcsuite/btcd/btcec/v2 v2.2.0 h1:fzn1qaOt32TuLjFlkzYSsBC35Q3KUjT1SwPxiMSCF5k= github.com/btcsuite/btcd/btcec/v2 v2.2.0/go.mod h1:U7MHm051Al6XmscBQ0BoNydpOTsFAn707034b5nY8zU= +github.com/btcsuite/btcd/chaincfg/chainhash v1.0.2 h1:KdUfX2zKommPRa+PD0sWZUyXe9w277ABlgELO7H04IM= +github.com/btcsuite/btcd/chaincfg/chainhash v1.0.2/go.mod h1:7SFka0XMvUgj3hfZtydOrQY2mwhPclbT2snogU7SQQc= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= -github.com/cerc-io/go-ethereum v1.11.5-statediff-5.0.1-alpha h1:x9muuG0Z2W/UAkwAq+0SXhYG9MCP6SJlGEfFNQa6JPI= -github.com/cerc-io/go-ethereum v1.11.5-statediff-5.0.1-alpha/go.mod h1:DIk2wFexjyzvyjuzSOtBEIAPRNZTnLXNbIHEyq1Igek= +github.com/cerc-io/go-ethereum v1.11.5-statediff-5.0.3-alpha h1:rhRmK/NeWMnQ07E4DuLb7WSh9FMotlXMPPaOrf8GJwM= +github.com/cerc-io/go-ethereum v1.11.5-statediff-5.0.3-alpha/go.mod h1:DIk2wFexjyzvyjuzSOtBEIAPRNZTnLXNbIHEyq1Igek= github.com/cerc-io/ipfs-ethdb/v5 v5.0.0-alpha h1:I1iXTaIjbTH8ehzNXmT2waXcYBifi1yjK6FK3W3a0Pg= github.com/cerc-io/ipfs-ethdb/v5 v5.0.0-alpha/go.mod h1:EGAdV/YewEADFDDVF1k9GNwy8vNWR29Xb87sRHgMIng= github.com/cespare/cp v0.1.0 h1:SE+dxFebS7Iik5LK0tsi1k9ZCxEaFX4AjQmoyA+1dJk= diff --git a/internal/util.go b/internal/util.go index a19df0a..634c18a 100644 --- a/internal/util.go +++ b/internal/util.go @@ -1,6 +1,10 @@ package internal import ( + "testing" + "time" + + pgipfsethdb "github.com/cerc-io/ipfs-ethdb/v5/postgres/v0" "github.com/ipfs/go-cid" "github.com/multiformats/go-multihash" ) @@ -12,3 +16,12 @@ func Keccak256ToCid(codec uint64, h []byte) (cid.Cid, error) { } return cid.NewCidV1(codec, buf), nil } + +// returns a cache config with unique name (groupcache names are global) +func MakeCacheConfig(t testing.TB) pgipfsethdb.CacheConfig { + return pgipfsethdb.CacheConfig{ + Name: t.Name(), + Size: 3000000, // 3MB + ExpiryDuration: 
time.Hour, + } +} diff --git a/test/compose.yml b/test/compose.yml new file mode 100644 index 0000000..d0ba3a2 --- /dev/null +++ b/test/compose.yml @@ -0,0 +1,26 @@ +# Containers to run backing DB for unit testing + +services: + migrations: + restart: on-failure + depends_on: + - ipld-eth-db + image: git.vdb.to/cerc-io/ipld-eth-db/ipld-eth-db:v5.0.2-alpha + environment: + DATABASE_USER: "vdbm" + DATABASE_NAME: "cerc_testing" + DATABASE_PASSWORD: "password" + DATABASE_HOSTNAME: "ipld-eth-db" + DATABASE_PORT: 5432 + + ipld-eth-db: + container_name: test-ipld-eth-db + image: timescale/timescaledb:latest-pg14 + restart: always + command: ["postgres", "-c", "log_statement=all"] + environment: + POSTGRES_USER: "vdbm" + POSTGRES_DB: "cerc_testing" + POSTGRES_PASSWORD: "password" + ports: + - "127.0.0.1:8077:5432" diff --git a/trie_by_cid/doc.go b/trie_by_cid/doc.go new file mode 100644 index 0000000..4b9b440 --- /dev/null +++ b/trie_by_cid/doc.go @@ -0,0 +1,3 @@ +// This package is a near complete copy of go-ethereum/trie and go-ethereum/core/state, modified to use +// a v0 IPFS blockstore as the backing DB, i.e. DB values are indexed by CID rather than hash. +package trie_by_cid diff --git a/trie_by_cid/helper/statediff_helper.go b/trie_by_cid/helper/statediff_helper.go index b03abf6..030c2b3 100644 --- a/trie_by_cid/helper/statediff_helper.go +++ b/trie_by_cid/helper/statediff_helper.go @@ -20,7 +20,10 @@ var ( mockTD = big.NewInt(1) ) -func IndexChain(dbConfig postgres.Config, stateCache state.Database, rootA, rootB common.Hash) error { +// IndexStateDiff indexes a single statediff. 
+// - uses TestChainConfig +// - block hash/number are left as zero +func IndexStateDiff(dbConfig postgres.Config, stateCache state.Database, rootA, rootB common.Hash) error { _, indexer, err := indexer.NewStateDiffIndexer( context.Background(), ChainConfig, node.Info{}, dbConfig) if err != nil { @@ -28,11 +31,10 @@ func IndexChain(dbConfig postgres.Config, stateCache state.Database, rootA, root } defer indexer.Close() // fixme: hangs when using PGX driver - // generating statediff payload for block, and transform the data into Postgres builder := statediff.NewBuilder(stateCache) block := types.NewBlock(&types.Header{Root: rootB}, nil, nil, nil, NewHasher()) - // todo: use dummy block hashes to just produce trie structure for testing + // uses zero block hash/number, we only need the trie structure here args := statediff.Args{ OldStateRoot: rootA, NewStateRoot: rootB, @@ -45,12 +47,7 @@ func IndexChain(dbConfig postgres.Config, stateCache state.Database, rootA, root if err != nil { return err } - // for _, node := range diff.Nodes { - // err := indexer.PushStateNode(tx, node, block.Hash().String()) - // if err != nil { - // return err - // } - // } + // we don't need to index diff.Nodes since we are just interested in the trie for _, ipld := range diff.IPLDs { if err := indexer.PushIPLD(tx, ipld); err != nil { return err diff --git a/trie_by_cid/state/access_list.go b/trie_by_cid/state/access_list.go new file mode 100644 index 0000000..4194691 --- /dev/null +++ b/trie_by_cid/state/access_list.go @@ -0,0 +1,136 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. 
+// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package state + +import ( + "github.com/ethereum/go-ethereum/common" +) + +type accessList struct { + addresses map[common.Address]int + slots []map[common.Hash]struct{} +} + +// ContainsAddress returns true if the address is in the access list. +func (al *accessList) ContainsAddress(address common.Address) bool { + _, ok := al.addresses[address] + return ok +} + +// Contains checks if a slot within an account is present in the access list, returning +// separate flags for the presence of the account and the slot respectively. +func (al *accessList) Contains(address common.Address, slot common.Hash) (addressPresent bool, slotPresent bool) { + idx, ok := al.addresses[address] + if !ok { + // no such address (and hence zero slots) + return false, false + } + if idx == -1 { + // address yes, but no slots + return true, false + } + _, slotPresent = al.slots[idx][slot] + return true, slotPresent +} + +// newAccessList creates a new accessList. +func newAccessList() *accessList { + return &accessList{ + addresses: make(map[common.Address]int), + } +} + +// Copy creates an independent copy of an accessList. 
+func (a *accessList) Copy() *accessList { + cp := newAccessList() + for k, v := range a.addresses { + cp.addresses[k] = v + } + cp.slots = make([]map[common.Hash]struct{}, len(a.slots)) + for i, slotMap := range a.slots { + newSlotmap := make(map[common.Hash]struct{}, len(slotMap)) + for k := range slotMap { + newSlotmap[k] = struct{}{} + } + cp.slots[i] = newSlotmap + } + return cp +} + +// AddAddress adds an address to the access list, and returns 'true' if the operation +// caused a change (addr was not previously in the list). +func (al *accessList) AddAddress(address common.Address) bool { + if _, present := al.addresses[address]; present { + return false + } + al.addresses[address] = -1 + return true +} + +// AddSlot adds the specified (addr, slot) combo to the access list. +// Return values are: +// - address added +// - slot added +// For any 'true' value returned, a corresponding journal entry must be made. +func (al *accessList) AddSlot(address common.Address, slot common.Hash) (addrChange bool, slotChange bool) { + idx, addrPresent := al.addresses[address] + if !addrPresent || idx == -1 { + // Address not present, or addr present but no slots there + al.addresses[address] = len(al.slots) + slotmap := map[common.Hash]struct{}{slot: {}} + al.slots = append(al.slots, slotmap) + return !addrPresent, true + } + // There is already an (address,slot) mapping + slotmap := al.slots[idx] + if _, ok := slotmap[slot]; !ok { + slotmap[slot] = struct{}{} + // Journal add slot change + return false, true + } + // No changes required + return false, false +} + +// DeleteSlot removes an (address, slot)-tuple from the access list. +// This operation needs to be performed in the same order as the addition happened. +// This method is meant to be used by the journal, which maintains ordering of +// operations. 
+func (al *accessList) DeleteSlot(address common.Address, slot common.Hash) { + idx, addrOk := al.addresses[address] + // There are two ways this can fail + if !addrOk { + panic("reverting slot change, address not present in list") + } + slotmap := al.slots[idx] + delete(slotmap, slot) + // If that was the last (first) slot, remove it + // Since additions and rollbacks are always performed in order, + // we can delete the item without worrying about screwing up later indices + if len(slotmap) == 0 { + al.slots = al.slots[:idx] + al.addresses[address] = -1 + } +} + +// DeleteAddress removes an address from the access list. This operation +// needs to be performed in the same order as the addition happened. +// This method is meant to be used by the journal, which maintains ordering of +// operations. +func (al *accessList) DeleteAddress(address common.Address) { + delete(al.addresses, address) +} diff --git a/trie_by_cid/state/database.go b/trie_by_cid/state/database.go index 30adbb7..e6920da 100644 --- a/trie_by_cid/state/database.go +++ b/trie_by_cid/state/database.go @@ -1,14 +1,30 @@ +// Copyright 2017 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . 
+ package state import ( "errors" + "fmt" - "github.com/VictoriaMetrics/fastcache" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/lru" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/statediff/indexer/ipld" - lru "github.com/hashicorp/golang-lru" "github.com/cerc-io/ipld-eth-statedb/internal" "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie" @@ -28,7 +44,10 @@ type Database interface { OpenTrie(root common.Hash) (Trie, error) // OpenStorageTrie opens the storage trie of an account. - OpenStorageTrie(addrHash, root common.Hash) (Trie, error) + OpenStorageTrie(stateRoot common.Hash, addrHash, root common.Hash) (Trie, error) + + // CopyTrie returns an independent copy of the given trie. + CopyTrie(Trie) Trie // ContractCode retrieves a particular contract's code. ContractCode(codeHash common.Hash) ([]byte, error) @@ -36,17 +55,79 @@ type Database interface { // ContractCodeSize retrieves a particular contracts code's size. ContractCodeSize(codeHash common.Hash) (int, error) + // DiskDB returns the underlying key-value disk database. + DiskDB() ethdb.KeyValueStore + // TrieDB retrieves the low level trie database used for data storage. TrieDB() *trie.Database } // Trie is a Ethereum Merkle Patricia trie. type Trie interface { + // GetKey returns the sha3 preimage of a hashed key that was previously used + // to store a value. + // + // TODO(fjl): remove this when StateTrie is removed + GetKey([]byte) []byte + + // TryGet returns the value for key stored in the trie. The value bytes must + // not be modified by the caller. If a node was not found in the database, a + // trie.MissingNodeError is returned. TryGet(key []byte) ([]byte, error) + + // TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not + // possible to use keybyte-encoding as the path might contain odd nibbles. 
TryGetNode(path []byte) ([]byte, int, error) - TryGetAccount(key []byte) (*types.StateAccount, error) + + // TryGetAccount abstracts an account read from the trie. It retrieves the + // account blob from the trie with provided account address and decodes it + // with associated decoding algorithm. If the specified account is not in + // the trie, nil will be returned. If the trie is corrupted(e.g. some nodes + // are missing or the account blob is incorrect for decoding), an error will + // be returned. + TryGetAccount(address common.Address) (*types.StateAccount, error) + + // TryUpdate associates key with value in the trie. If value has length zero, any + // existing value is deleted from the trie. The value bytes must not be modified + // by the caller while they are stored in the trie. If a node was not found in the + // database, a trie.MissingNodeError is returned. + TryUpdate(key, value []byte) error + + // TryUpdateAccount abstracts an account write to the trie. It encodes the + // provided account object with associated algorithm and then updates it + // in the trie with provided address. + TryUpdateAccount(address common.Address, account *types.StateAccount) error + + // TryDelete removes any existing value for key from the trie. If a node was not + // found in the database, a trie.MissingNodeError is returned. + TryDelete(key []byte) error + + // TryDeleteAccount abstracts an account deletion from the trie. + TryDeleteAccount(address common.Address) error + + // Hash returns the root hash of the trie. It does not write to the database and + // can be used even if the trie doesn't have one. Hash() common.Hash + + // Commit collects all dirty nodes in the trie and replace them with the + // corresponding node hash. All collected nodes(including dirty leaves if + // collectLeaf is true) will be encapsulated into a nodeset for return. + // The returned nodeset can be nil if the trie is clean(nothing to commit). 
+ // Once the trie is committed, it's not usable anymore. A new trie must + // be created with new root and updated trie database for following usage + Commit(collectLeaf bool) (common.Hash, *trie.NodeSet) + + // NodeIterator returns an iterator that returns nodes of the trie. Iteration + // starts at the key after the given start key. NodeIterator(startKey []byte) trie.NodeIterator + + // Prove constructs a Merkle proof for key. The result contains all encoded nodes + // on the path to the value at key. The value itself is also included in the last + // node and can be retrieved by verifying the proof. + // + // If the trie does not contain a value for key, the returned proof contains all + // nodes of the longest existing prefix of the key (at least the root), ending + // with the node that proves the absence of the key. Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) error } @@ -61,23 +142,34 @@ func NewDatabase(db ethdb.Database) Database { // is safe for concurrent use and retains a lot of collapsed RLP trie nodes in a // large memory cache. func NewDatabaseWithConfig(db ethdb.Database, config *trie.Config) Database { - csc, _ := lru.New(codeSizeCacheSize) return &cachingDB{ - db: trie.NewDatabaseWithConfig(db, config), - codeSizeCache: csc, - codeCache: fastcache.New(codeCacheSize), + disk: db, + codeSizeCache: lru.NewCache[common.Hash, int](codeSizeCacheSize), + codeCache: lru.NewSizeConstrainedCache[common.Hash, []byte](codeCacheSize), + triedb: trie.NewDatabaseWithConfig(db, config), + } +} + +// NewDatabaseWithNodeDB creates a state database with an already initialized node database. 
+func NewDatabaseWithNodeDB(db ethdb.Database, triedb *trie.Database) Database { + return &cachingDB{ + disk: db, + codeSizeCache: lru.NewCache[common.Hash, int](codeSizeCacheSize), + codeCache: lru.NewSizeConstrainedCache[common.Hash, []byte](codeCacheSize), + triedb: triedb, } } type cachingDB struct { - db *trie.Database - codeSizeCache *lru.Cache - codeCache *fastcache.Cache + disk ethdb.KeyValueStore + codeSizeCache *lru.Cache[common.Hash, int] + codeCache *lru.SizeConstrainedCache[common.Hash, []byte] + triedb *trie.Database } // OpenTrie opens the main account trie at a specific root hash. func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) { - tr, err := trie.NewStateTrie(common.Hash{}, root, db.db) + tr, err := trie.NewStateTrie(trie.StateTrieID(root), db.triedb, trie.StateTrieCodec) if err != nil { return nil, err } @@ -85,29 +177,40 @@ func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) { } // OpenStorageTrie opens the storage trie of an account. -func (db *cachingDB) OpenStorageTrie(addrHash, root common.Hash) (Trie, error) { - tr, err := trie.NewStorageTrie(addrHash, root, db.db) +func (db *cachingDB) OpenStorageTrie(stateRoot common.Hash, addrHash, root common.Hash) (Trie, error) { + tr, err := trie.NewStateTrie(trie.StorageTrieID(stateRoot, addrHash, root), db.triedb, trie.StorageTrieCodec) if err != nil { return nil, err } return tr, nil } +// CopyTrie returns an independent copy of the given trie. +func (db *cachingDB) CopyTrie(t Trie) Trie { + switch t := t.(type) { + case *trie.StateTrie: + return t.Copy() + default: + panic(fmt.Errorf("unknown trie type %T", t)) + } +} + // ContractCode retrieves a particular contract's code. 
func (db *cachingDB) ContractCode(codeHash common.Hash) ([]byte, error) { - if code := db.codeCache.Get(nil, codeHash.Bytes()); len(code) > 0 { + code, _ := db.codeCache.Get(codeHash) + if len(code) > 0 { return code, nil } - codeCID, err := internal.Keccak256ToCid(ipld.RawBinary, codeHash.Bytes()) + cid, err := internal.Keccak256ToCid(ipld.RawBinary, codeHash.Bytes()) if err != nil { return nil, err } - code, err := db.db.DiskDB().Get(codeCID.Bytes()) + code, err = db.disk.Get(cid.Bytes()) if err != nil { return nil, err } if len(code) > 0 { - db.codeCache.Set(codeHash.Bytes(), code) + db.codeCache.Add(codeHash, code) db.codeSizeCache.Add(codeHash, len(code)) return code, nil } @@ -117,13 +220,18 @@ func (db *cachingDB) ContractCode(codeHash common.Hash) ([]byte, error) { // ContractCodeSize retrieves a particular contracts code's size. func (db *cachingDB) ContractCodeSize(codeHash common.Hash) (int, error) { if cached, ok := db.codeSizeCache.Get(codeHash); ok { - return cached.(int), nil + return cached, nil } code, err := db.ContractCode(codeHash) return len(code), err } +// DiskDB returns the underlying key-value disk database. +func (db *cachingDB) DiskDB() ethdb.KeyValueStore { + return db.disk +} + // TrieDB retrieves any intermediate trie-node caching layer. func (db *cachingDB) TrieDB() *trie.Database { - return db.db + return db.triedb } diff --git a/trie_by_cid/state/journal.go b/trie_by_cid/state/journal.go new file mode 100644 index 0000000..1722fb4 --- /dev/null +++ b/trie_by_cid/state/journal.go @@ -0,0 +1,282 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. 
+// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package state + +import ( + "math/big" + + "github.com/ethereum/go-ethereum/common" +) + +// journalEntry is a modification entry in the state change journal that can be +// reverted on demand. +type journalEntry interface { + // revert undoes the changes introduced by this journal entry. + revert(*StateDB) + + // dirtied returns the Ethereum address modified by this journal entry. + dirtied() *common.Address +} + +// journal contains the list of state modifications applied since the last state +// commit. These are tracked to be able to be reverted in the case of an execution +// exception or request for reversal. +type journal struct { + entries []journalEntry // Current changes tracked by the journal + dirties map[common.Address]int // Dirty accounts and the number of changes +} + +// newJournal creates a new initialized journal. +func newJournal() *journal { + return &journal{ + dirties: make(map[common.Address]int), + } +} + +// append inserts a new modification entry to the end of the change journal. +func (j *journal) append(entry journalEntry) { + j.entries = append(j.entries, entry) + if addr := entry.dirtied(); addr != nil { + j.dirties[*addr]++ + } +} + +// revert undoes a batch of journalled modifications along with any reverted +// dirty handling too. 
+func (j *journal) revert(statedb *StateDB, snapshot int) { + for i := len(j.entries) - 1; i >= snapshot; i-- { + // Undo the changes made by the operation + j.entries[i].revert(statedb) + + // Drop any dirty tracking induced by the change + if addr := j.entries[i].dirtied(); addr != nil { + if j.dirties[*addr]--; j.dirties[*addr] == 0 { + delete(j.dirties, *addr) + } + } + } + j.entries = j.entries[:snapshot] +} + +// dirty explicitly sets an address to dirty, even if the change entries would +// otherwise suggest it as clean. This method is an ugly hack to handle the RIPEMD +// precompile consensus exception. +func (j *journal) dirty(addr common.Address) { + j.dirties[addr]++ +} + +// length returns the current number of entries in the journal. +func (j *journal) length() int { + return len(j.entries) +} + +type ( + // Changes to the account trie. + createObjectChange struct { + account *common.Address + } + resetObjectChange struct { + prev *stateObject + prevdestruct bool + } + suicideChange struct { + account *common.Address + prev bool // whether account had already suicided + prevbalance *big.Int + } + + // Changes to individual accounts. + balanceChange struct { + account *common.Address + prev *big.Int + } + nonceChange struct { + account *common.Address + prev uint64 + } + storageChange struct { + account *common.Address + key, prevalue common.Hash + } + codeChange struct { + account *common.Address + prevcode, prevhash []byte + } + + // Changes to other state values. 
+ refundChange struct { + prev uint64 + } + addLogChange struct { + txhash common.Hash + } + addPreimageChange struct { + hash common.Hash + } + touchChange struct { + account *common.Address + } + // Changes to the access list + accessListAddAccountChange struct { + address *common.Address + } + accessListAddSlotChange struct { + address *common.Address + slot *common.Hash + } + + transientStorageChange struct { + account *common.Address + key, prevalue common.Hash + } +) + +func (ch createObjectChange) revert(s *StateDB) { + delete(s.stateObjects, *ch.account) + delete(s.stateObjectsDirty, *ch.account) +} + +func (ch createObjectChange) dirtied() *common.Address { + return ch.account +} + +func (ch resetObjectChange) revert(s *StateDB) { + s.setStateObject(ch.prev) + if !ch.prevdestruct { + delete(s.stateObjectsDestruct, ch.prev.address) + } +} + +func (ch resetObjectChange) dirtied() *common.Address { + return nil +} + +func (ch suicideChange) revert(s *StateDB) { + obj := s.getStateObject(*ch.account) + if obj != nil { + obj.suicided = ch.prev + obj.setBalance(ch.prevbalance) + } +} + +func (ch suicideChange) dirtied() *common.Address { + return ch.account +} + +var ripemd = common.HexToAddress("0000000000000000000000000000000000000003") + +func (ch touchChange) revert(s *StateDB) { +} + +func (ch touchChange) dirtied() *common.Address { + return ch.account +} + +func (ch balanceChange) revert(s *StateDB) { + s.getStateObject(*ch.account).setBalance(ch.prev) +} + +func (ch balanceChange) dirtied() *common.Address { + return ch.account +} + +func (ch nonceChange) revert(s *StateDB) { + s.getStateObject(*ch.account).setNonce(ch.prev) +} + +func (ch nonceChange) dirtied() *common.Address { + return ch.account +} + +func (ch codeChange) revert(s *StateDB) { + s.getStateObject(*ch.account).setCode(common.BytesToHash(ch.prevhash), ch.prevcode) +} + +func (ch codeChange) dirtied() *common.Address { + return ch.account +} + +func (ch storageChange) revert(s *StateDB) { 
+ s.getStateObject(*ch.account).setState(ch.key, ch.prevalue) +} + +func (ch storageChange) dirtied() *common.Address { + return ch.account +} + +func (ch transientStorageChange) revert(s *StateDB) { + s.setTransientState(*ch.account, ch.key, ch.prevalue) +} + +func (ch transientStorageChange) dirtied() *common.Address { + return nil +} + +func (ch refundChange) revert(s *StateDB) { + s.refund = ch.prev +} + +func (ch refundChange) dirtied() *common.Address { + return nil +} + +func (ch addLogChange) revert(s *StateDB) { + logs := s.logs[ch.txhash] + if len(logs) == 1 { + delete(s.logs, ch.txhash) + } else { + s.logs[ch.txhash] = logs[:len(logs)-1] + } + s.logSize-- +} + +func (ch addLogChange) dirtied() *common.Address { + return nil +} + +func (ch addPreimageChange) revert(s *StateDB) { + delete(s.preimages, ch.hash) +} + +func (ch addPreimageChange) dirtied() *common.Address { + return nil +} + +func (ch accessListAddAccountChange) revert(s *StateDB) { + /* + One important invariant here, is that whenever a (addr, slot) is added, if the + addr is not already present, the add causes two journal entries: + - one for the address, + - one for the (address,slot) + Therefore, when unrolling the change, we can always blindly delete the + (addr) at this point, since no storage adds can remain when come upon + a single (addr) change. + */ + s.accessList.DeleteAddress(*ch.address) +} + +func (ch accessListAddAccountChange) dirtied() *common.Address { + return nil +} + +func (ch accessListAddSlotChange) revert(s *StateDB) { + s.accessList.DeleteSlot(*ch.address, *ch.slot) +} + +func (ch accessListAddSlotChange) dirtied() *common.Address { + return nil +} diff --git a/trie_by_cid/state/metrics.go b/trie_by_cid/state/metrics.go new file mode 100644 index 0000000..e702ef3 --- /dev/null +++ b/trie_by_cid/state/metrics.go @@ -0,0 +1,30 @@ +// Copyright 2021 The go-ethereum Authors +// This file is part of the go-ethereum library. 
+// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package state + +import "github.com/ethereum/go-ethereum/metrics" + +var ( + accountUpdatedMeter = metrics.NewRegisteredMeter("state/update/account", nil) + storageUpdatedMeter = metrics.NewRegisteredMeter("state/update/storage", nil) + accountDeletedMeter = metrics.NewRegisteredMeter("state/delete/account", nil) + storageDeletedMeter = metrics.NewRegisteredMeter("state/delete/storage", nil) + accountTrieUpdatedMeter = metrics.NewRegisteredMeter("state/update/accountnodes", nil) + storageTriesUpdatedMeter = metrics.NewRegisteredMeter("state/update/storagenodes", nil) + accountTrieDeletedMeter = metrics.NewRegisteredMeter("state/delete/accountnodes", nil) + storageTriesDeletedMeter = metrics.NewRegisteredMeter("state/delete/storagenodes", nil) +) diff --git a/trie_by_cid/state/state_object.go b/trie_by_cid/state/state_object.go index 6b53c67..e1afcb3 100644 --- a/trie_by_cid/state/state_object.go +++ b/trie_by_cid/state/state_object.go @@ -1,8 +1,25 @@ +// Copyright 2014 The go-ethereum Authors +// This file is part of the go-ethereum library. 
+// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + package state import ( "bytes" "fmt" + "io" "math/big" "time" @@ -11,15 +28,7 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/rlp" -) - -var ( - // emptyRoot is the known root hash of an empty trie. - // this is calculated as: emptyRoot = crypto.Keccak256(rlp.Encode([][]byte{})) - // that is, the keccak356 hash of the rlp encoding of an empty trie node (empty byte slice array) - emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") - // emptyCodeHash is the CodeHash for an EOA, for an account without contract code deployed - emptyCodeHash = crypto.Keccak256(nil) + // "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie" ) type Code []byte @@ -34,7 +43,6 @@ func (s Storage) String() (str string) { for key, value := range s { str += fmt.Sprintf("%X : %X\n", key, value) } - return } @@ -43,32 +51,39 @@ func (s Storage) Copy() Storage { for key, value := range s { cpy[key] = value } - return cpy } -// stateObject represents an Ethereum account which is being accessed. +// stateObject represents an Ethereum account which is being modified. // // The usage pattern is as follows: // First you need to obtain a state object. -// Account values can be accessed through the object. 
+// Account values can be accessed and modified through the object. type stateObject struct { address common.Address addrHash common.Hash // hash of ethereum address of the account data types.StateAccount db *StateDB - // Caches. + // Write caches. trie Trie // storage trie, which becomes non-nil on first access code Code // contract bytecode, which gets set when code is loaded - originStorage Storage // Storage cache of original entries to dedup rewrites, reset for every transaction - fakeStorage Storage // Fake storage which constructed by caller for debugging purpose. + originStorage Storage // Storage cache of original entries to dedup rewrites, reset for every transaction + pendingStorage Storage // Storage entries that need to be flushed to disk, at the end of an entire block + dirtyStorage Storage // Storage entries that have been modified in the current transaction execution + + // Cache flags. + // When an object is marked suicided it will be deleted from the trie + // during the "update" phase of the state transition. + dirtyCode bool // true if the code was updated + suicided bool + deleted bool } // empty returns whether the account is considered empty. func (s *stateObject) empty() bool { - return s.data.Nonce == 0 && s.data.Balance.Sign() == 0 && bytes.Equal(s.data.CodeHash, emptyCodeHash) + return s.data.Nonce == 0 && s.data.Balance.Sign() == 0 && bytes.Equal(s.data.CodeHash, types.EmptyCodeHash.Bytes()) } // newObject creates a state object. 
@@ -77,71 +92,128 @@ func newObject(db *StateDB, address common.Address, data types.StateAccount) *st data.Balance = new(big.Int) } if data.CodeHash == nil { - data.CodeHash = emptyCodeHash + data.CodeHash = types.EmptyCodeHash.Bytes() } if data.Root == (common.Hash{}) { - data.Root = emptyRoot + data.Root = types.EmptyRootHash } return &stateObject{ - db: db, - address: address, - addrHash: crypto.Keccak256Hash(address[:]), - data: data, - originStorage: make(Storage), + db: db, + address: address, + addrHash: crypto.Keccak256Hash(address[:]), + data: data, + originStorage: make(Storage), + pendingStorage: make(Storage), + dirtyStorage: make(Storage), } } -// setError remembers the first non-nil error it is called with. -func (s *stateObject) setError(err error) { - s.db.setError(err) +// EncodeRLP implements rlp.Encoder. +func (s *stateObject) EncodeRLP(w io.Writer) error { + return rlp.Encode(w, &s.data) } -func (s *stateObject) getTrie(db Database) Trie { +func (s *stateObject) markSuicided() { + s.suicided = true +} + +func (s *stateObject) touch() { + s.db.journal.append(touchChange{ + account: &s.address, + }) + if s.address == ripemd { + // Explicitly put it in the dirty-cache, which is otherwise generated from + // flattened journals. + s.db.journal.dirty(s.address) + } +} + +// getTrie returns the associated storage trie. The trie will be opened +// if it's not loaded previously. An error will be returned if trie can't +// be loaded. 
+func (s *stateObject) getTrie(db Database) (Trie, error) { if s.trie == nil { - // // Try fetching from prefetcher first - // // We don't prefetch empty tries - // if s.data.Root != emptyRoot && s.db.prefetcher != nil { - // // When the miner is creating the pending state, there is no - // // prefetcher - // s.trie = s.db.prefetcher.trie(s.addrHash, s.data.Root) - // } + // Try fetching from prefetcher first + // We don't prefetch empty tries + if s.data.Root != types.EmptyRootHash && s.db.prefetcher != nil { + // When the miner is creating the pending state, there is no + // prefetcher + s.trie = s.db.prefetcher.trie(s.addrHash, s.data.Root) + } if s.trie == nil { - var err error - s.trie, err = db.OpenStorageTrie(s.addrHash, s.data.Root) + tr, err := db.OpenStorageTrie(s.db.originalRoot, s.addrHash, s.data.Root) if err != nil { - s.trie, _ = db.OpenStorageTrie(s.addrHash, common.Hash{}) - s.setError(fmt.Errorf("can't create storage trie: %w", err)) + return nil, err } + s.trie = tr } } - return s.trie + return s.trie, nil +} + +// GetState retrieves a value from the account storage trie. +func (s *stateObject) GetState(db Database, key common.Hash) common.Hash { + // If we have a dirty value for this state entry, return it + value, dirty := s.dirtyStorage[key] + if dirty { + return value + } + // Otherwise return the entry's original value + return s.GetCommittedState(db, key) } // GetCommittedState retrieves a value from the committed account storage trie. 
-func (s *stateObject) GetState(db Database, key common.Hash) common.Hash { - // If the fake storage is set, only lookup the state here(in the debugging mode) - if s.fakeStorage != nil { - return s.fakeStorage[key] +func (s *stateObject) GetCommittedState(db Database, key common.Hash) common.Hash { + // If we have a pending write or clean cached, return that + if value, pending := s.pendingStorage[key]; pending { + return value } - // If we have a cached value, return that if value, cached := s.originStorage[key]; cached { return value } - // If no live objects are available, load from the database. - start := time.Now() - enc, err := s.getTrie(db).TryGet(key.Bytes()) - if metrics.EnabledExpensive { - s.db.StorageReads += time.Since(start) - } - if err != nil { - s.setError(err) + // If the object was destructed in *this* block (and potentially resurrected), + // the storage has been cleared out, and we should *not* consult the previous + // database about any storage values. The only possible alternatives are: + // 1) resurrect happened, and new slot values were set -- those should + // have been handles via pendingStorage above. + // 2) we don't have new values, and can deliver empty response back + if _, destructed := s.db.stateObjectsDestruct[s.address]; destructed { return common.Hash{} } + // If no live objects are available, attempt to use snapshots + var ( + enc []byte + err error + ) + if s.db.snap != nil { + start := time.Now() + enc, err = s.db.snap.Storage(s.addrHash, crypto.Keccak256Hash(key.Bytes())) + if metrics.EnabledExpensive { + s.db.SnapshotStorageReads += time.Since(start) + } + } + // If the snapshot is unavailable or reading from it fails, load from the database. 
+ if s.db.snap == nil || err != nil { + start := time.Now() + tr, err := s.getTrie(db) + if err != nil { + s.db.setError(err) + return common.Hash{} + } + enc, err = tr.TryGet(key.Bytes()) + if metrics.EnabledExpensive { + s.db.StorageReads += time.Since(start) + } + if err != nil { + s.db.setError(err) + return common.Hash{} + } + } var value common.Hash if len(enc) > 0 { _, content, _, err := rlp.Split(enc) if err != nil { - s.setError(err) + s.db.setError(err) } value.SetBytes(content) } @@ -149,6 +221,182 @@ func (s *stateObject) GetState(db Database, key common.Hash) common.Hash { return value } +// SetState updates a value in account storage. +func (s *stateObject) SetState(db Database, key, value common.Hash) { + // If the new value is the same as old, don't set + prev := s.GetState(db, key) + if prev == value { + return + } + // New value is different, update and journal the change + s.db.journal.append(storageChange{ + account: &s.address, + key: key, + prevalue: prev, + }) + s.setState(key, value) +} + +func (s *stateObject) setState(key, value common.Hash) { + s.dirtyStorage[key] = value +} + +// finalise moves all dirty storage slots into the pending area to be hashed or +// committed later. It is invoked at the end of every transaction. +func (s *stateObject) finalise(prefetch bool) { + slotsToPrefetch := make([][]byte, 0, len(s.dirtyStorage)) + for key, value := range s.dirtyStorage { + s.pendingStorage[key] = value + if value != s.originStorage[key] { + slotsToPrefetch = append(slotsToPrefetch, common.CopyBytes(key[:])) // Copy needed for closure + } + } + if s.db.prefetcher != nil && prefetch && len(slotsToPrefetch) > 0 && s.data.Root != types.EmptyRootHash { + s.db.prefetcher.prefetch(s.addrHash, s.data.Root, slotsToPrefetch) + } + if len(s.dirtyStorage) > 0 { + s.dirtyStorage = make(Storage) + } +} + +// updateTrie writes cached storage modifications into the object's storage trie. 
+// It will return nil if the trie has not been loaded and no changes have been +// made. An error will be returned if the trie can't be loaded/updated correctly. +func (s *stateObject) updateTrie(db Database) (Trie, error) { + // Make sure all dirty slots are finalized into the pending storage area + s.finalise(false) // Don't prefetch anymore, pull directly if need be + if len(s.pendingStorage) == 0 { + return s.trie, nil + } + // Track the amount of time wasted on updating the storage trie + if metrics.EnabledExpensive { + defer func(start time.Time) { s.db.StorageUpdates += time.Since(start) }(time.Now()) + } + // The snapshot storage map for the object + var ( + storage map[common.Hash][]byte + hasher = s.db.hasher + ) + tr, err := s.getTrie(db) + if err != nil { + s.db.setError(err) + return nil, err + } + // Insert all the pending updates into the trie + usedStorage := make([][]byte, 0, len(s.pendingStorage)) + for key, value := range s.pendingStorage { + // Skip noop changes, persist actual changes + if value == s.originStorage[key] { + continue + } + s.originStorage[key] = value + + var v []byte + if (value == common.Hash{}) { + if err := tr.TryDelete(key[:]); err != nil { + s.db.setError(err) + return nil, err + } + s.db.StorageDeleted += 1 + } else { + // Encoding []byte cannot fail, ok to ignore the error. 
+ v, _ = rlp.EncodeToBytes(common.TrimLeftZeroes(value[:])) + if err := tr.TryUpdate(key[:], v); err != nil { + s.db.setError(err) + return nil, err + } + s.db.StorageUpdated += 1 + } + // If state snapshotting is active, cache the data til commit + if s.db.snap != nil { + if storage == nil { + // Retrieve the old storage map, if available, create a new one otherwise + if storage = s.db.snapStorage[s.addrHash]; storage == nil { + storage = make(map[common.Hash][]byte) + s.db.snapStorage[s.addrHash] = storage + } + } + storage[crypto.HashData(hasher, key[:])] = v // v will be nil if it's deleted + } + usedStorage = append(usedStorage, common.CopyBytes(key[:])) // Copy needed for closure + } + if s.db.prefetcher != nil { + s.db.prefetcher.used(s.addrHash, s.data.Root, usedStorage) + } + if len(s.pendingStorage) > 0 { + s.pendingStorage = make(Storage) + } + return tr, nil +} + +// UpdateRoot sets the trie root to the current root hash of. An error +// will be returned if trie root hash is not computed correctly. +func (s *stateObject) updateRoot(db Database) { + tr, err := s.updateTrie(db) + if err != nil { + return + } + // If nothing changed, don't bother with hashing anything + if tr == nil { + return + } + // Track the amount of time wasted on hashing the storage trie + if metrics.EnabledExpensive { + defer func(start time.Time) { s.db.StorageHashes += time.Since(start) }(time.Now()) + } + s.data.Root = tr.Hash() +} + +// AddBalance adds amount to s's balance. +// It is used to add funds to the destination account of a transfer. +func (s *stateObject) AddBalance(amount *big.Int) { + // EIP161: We must check emptiness for the objects such that the account + // clearing (0,0,0 objects) can take effect. + if amount.Sign() == 0 { + if s.empty() { + s.touch() + } + return + } + s.SetBalance(new(big.Int).Add(s.Balance(), amount)) +} + +// SubBalance removes amount from s's balance. +// It is used to remove funds from the origin account of a transfer. 
+func (s *stateObject) SubBalance(amount *big.Int) { + if amount.Sign() == 0 { + return + } + s.SetBalance(new(big.Int).Sub(s.Balance(), amount)) +} + +func (s *stateObject) SetBalance(amount *big.Int) { + s.db.journal.append(balanceChange{ + account: &s.address, + prev: new(big.Int).Set(s.data.Balance), + }) + s.setBalance(amount) +} + +func (s *stateObject) setBalance(amount *big.Int) { + s.data.Balance = amount +} + +func (s *stateObject) deepCopy(db *StateDB) *stateObject { + stateObject := newObject(db, s.address, s.data) + if s.trie != nil { + stateObject.trie = db.db.CopyTrie(s.trie) + } + stateObject.code = s.code + stateObject.dirtyStorage = s.dirtyStorage.Copy() + stateObject.originStorage = s.originStorage.Copy() + stateObject.pendingStorage = s.pendingStorage.Copy() + stateObject.suicided = s.suicided + stateObject.dirtyCode = s.dirtyCode + stateObject.deleted = s.deleted + return stateObject +} + // // Attribute accessors // @@ -163,12 +411,12 @@ func (s *stateObject) Code(db Database) []byte { if s.code != nil { return s.code } - if bytes.Equal(s.CodeHash(), emptyCodeHash) { + if bytes.Equal(s.CodeHash(), types.EmptyCodeHash.Bytes()) { return nil } code, err := db.ContractCode(common.BytesToHash(s.CodeHash())) if err != nil { - s.setError(fmt.Errorf("can't load code hash %x: %v", s.CodeHash(), err)) + s.db.setError(fmt.Errorf("can't load code hash %x: %v", s.CodeHash(), err)) } s.code = code return code @@ -181,16 +429,44 @@ func (s *stateObject) CodeSize(db Database) int { if s.code != nil { return len(s.code) } - if bytes.Equal(s.CodeHash(), emptyCodeHash) { + if bytes.Equal(s.CodeHash(), types.EmptyCodeHash.Bytes()) { return 0 } size, err := db.ContractCodeSize(common.BytesToHash(s.CodeHash())) if err != nil { - s.setError(fmt.Errorf("can't load code size %x: %v", s.CodeHash(), err)) + s.db.setError(fmt.Errorf("can't load code size %x: %v", s.CodeHash(), err)) } return size } +func (s *stateObject) SetCode(codeHash common.Hash, code []byte) { + 
prevcode := s.Code(s.db.db) + s.db.journal.append(codeChange{ + account: &s.address, + prevhash: s.CodeHash(), + prevcode: prevcode, + }) + s.setCode(codeHash, code) +} + +func (s *stateObject) setCode(codeHash common.Hash, code []byte) { + s.code = code + s.data.CodeHash = codeHash[:] + s.dirtyCode = true +} + +func (s *stateObject) SetNonce(nonce uint64) { + s.db.journal.append(nonceChange{ + account: &s.address, + prev: s.data.Nonce, + }) + s.setNonce(nonce) +} + +func (s *stateObject) setNonce(nonce uint64) { + s.data.Nonce = nonce +} + func (s *stateObject) CodeHash() []byte { return s.data.CodeHash } @@ -202,10 +478,3 @@ func (s *stateObject) Balance() *big.Int { func (s *stateObject) Nonce() uint64 { return s.data.Nonce } - -// Never called, but must be present to allow stateObject to be used -// as a vm.Account interface that also satisfies the vm.ContractRef -// interface. Interfaces are awesome. -func (s *stateObject) Value() *big.Int { - panic("Value on stateObject should never be called") -} diff --git a/trie_by_cid/state/state_object_test.go b/trie_by_cid/state/state_object_test.go new file mode 100644 index 0000000..42fd778 --- /dev/null +++ b/trie_by_cid/state/state_object_test.go @@ -0,0 +1,46 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. 
If not, see . + +package state + +import ( + "bytes" + "testing" + + "github.com/ethereum/go-ethereum/common" +) + +func BenchmarkCutOriginal(b *testing.B) { + value := common.HexToHash("0x01") + for i := 0; i < b.N; i++ { + bytes.TrimLeft(value[:], "\x00") + } +} + +func BenchmarkCutsetterFn(b *testing.B) { + value := common.HexToHash("0x01") + cutSetFn := func(r rune) bool { return r == 0 } + for i := 0; i < b.N; i++ { + bytes.TrimLeftFunc(value[:], cutSetFn) + } +} + +func BenchmarkCutCustomTrim(b *testing.B) { + value := common.HexToHash("0x01") + for i := 0; i < b.N; i++ { + common.TrimLeftZeroes(value[:]) + } +} diff --git a/trie_by_cid/state/state_test.go b/trie_by_cid/state/state_test.go new file mode 100644 index 0000000..6b3fa85 --- /dev/null +++ b/trie_by_cid/state/state_test.go @@ -0,0 +1,219 @@ +// Copyright 2014 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . 
+ +package state + +import ( + "bytes" + "context" + "math/big" + "testing" + + pgipfsethdb "github.com/cerc-io/ipfs-ethdb/v5/postgres/v0" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/statediff/indexer/database/sql/postgres" + + "github.com/cerc-io/ipld-eth-statedb/internal" +) + +var ( + testCtx = context.Background() + testConfig, _ = postgres.DefaultConfig.WithEnv() + teardownStatements = []string{`TRUNCATE ipld.blocks`} +) + +type stateTest struct { + db ethdb.Database + state *StateDB +} + +func newStateTest(t *testing.T) *stateTest { + pool, err := postgres.ConnectSQLX(testCtx, testConfig) + if err != nil { + t.Fatal(err) + } + db := pgipfsethdb.NewDatabase(pool, internal.MakeCacheConfig(t)) + sdb, err := New(common.Hash{}, NewDatabase(db), nil) + if err != nil { + t.Fatal(err) + } + return &stateTest{db: db, state: sdb} +} + +func TestNull(t *testing.T) { + s := newStateTest(t) + address := common.HexToAddress("0x823140710bf13990e4500136726d8b55") + s.state.CreateAccount(address) + //value := common.FromHex("0x823140710bf13990e4500136726d8b55") + var value common.Hash + + s.state.SetState(address, common.Hash{}, value) + // s.state.Commit(false) + + if value := s.state.GetState(address, common.Hash{}); value != (common.Hash{}) { + t.Errorf("expected empty current value, got %x", value) + } + if value := s.state.GetCommittedState(address, common.Hash{}); value != (common.Hash{}) { + t.Errorf("expected empty committed value, got %x", value) + } +} + +func TestSnapshot(t *testing.T) { + stateobjaddr := common.BytesToAddress([]byte("aa")) + var storageaddr common.Hash + data1 := common.BytesToHash([]byte{42}) + data2 := common.BytesToHash([]byte{43}) + s := newStateTest(t) + + // snapshot the genesis state + genesis := s.state.Snapshot() + + // set initial state object value + 
s.state.SetState(stateobjaddr, storageaddr, data1) + snapshot := s.state.Snapshot() + + // set a new state object value, revert it and ensure correct content + s.state.SetState(stateobjaddr, storageaddr, data2) + s.state.RevertToSnapshot(snapshot) + + if v := s.state.GetState(stateobjaddr, storageaddr); v != data1 { + t.Errorf("wrong storage value %v, want %v", v, data1) + } + if v := s.state.GetCommittedState(stateobjaddr, storageaddr); v != (common.Hash{}) { + t.Errorf("wrong committed storage value %v, want %v", v, common.Hash{}) + } + + // revert up to the genesis state and ensure correct content + s.state.RevertToSnapshot(genesis) + if v := s.state.GetState(stateobjaddr, storageaddr); v != (common.Hash{}) { + t.Errorf("wrong storage value %v, want %v", v, common.Hash{}) + } + if v := s.state.GetCommittedState(stateobjaddr, storageaddr); v != (common.Hash{}) { + t.Errorf("wrong committed storage value %v, want %v", v, common.Hash{}) + } +} + +func TestSnapshotEmpty(t *testing.T) { + s := newStateTest(t) + s.state.RevertToSnapshot(s.state.Snapshot()) +} + +func TestSnapshot2(t *testing.T) { + state, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil) + + stateobjaddr0 := common.BytesToAddress([]byte("so0")) + stateobjaddr1 := common.BytesToAddress([]byte("so1")) + var storageaddr common.Hash + + data0 := common.BytesToHash([]byte{17}) + data1 := common.BytesToHash([]byte{18}) + + state.SetState(stateobjaddr0, storageaddr, data0) + state.SetState(stateobjaddr1, storageaddr, data1) + + // db, trie are already non-empty values + so0 := state.getStateObject(stateobjaddr0) + so0.SetBalance(big.NewInt(42)) + so0.SetNonce(43) + so0.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e'}), []byte{'c', 'a', 'f', 'e'}) + so0.suicided = false + so0.deleted = false + state.setStateObject(so0) + + // root, _ := state.Commit(false) + // state, _ = New(root, state.db, state.snaps) + + // and one with deleted == true + so1 := state.getStateObject(stateobjaddr1) 
+ so1.SetBalance(big.NewInt(52)) + so1.SetNonce(53) + so1.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e', '2'}), []byte{'c', 'a', 'f', 'e', '2'}) + so1.suicided = true + so1.deleted = true + state.setStateObject(so1) + + so1 = state.getStateObject(stateobjaddr1) + if so1 != nil { + t.Fatalf("deleted object not nil when getting") + } + + snapshot := state.Snapshot() + state.RevertToSnapshot(snapshot) + + so0Restored := state.getStateObject(stateobjaddr0) + // Update lazily-loaded values before comparing. + so0Restored.GetState(state.db, storageaddr) + so0Restored.Code(state.db) + // non-deleted is equal (restored) + compareStateObjects(so0Restored, so0, t) + + // deleted should be nil, both before and after restore of state copy + so1Restored := state.getStateObject(stateobjaddr1) + if so1Restored != nil { + t.Fatalf("deleted object not nil after restoring snapshot: %+v", so1Restored) + } +} + +func compareStateObjects(so0, so1 *stateObject, t *testing.T) { + if so0.Address() != so1.Address() { + t.Fatalf("Address mismatch: have %v, want %v", so0.address, so1.address) + } + if so0.Balance().Cmp(so1.Balance()) != 0 { + t.Fatalf("Balance mismatch: have %v, want %v", so0.Balance(), so1.Balance()) + } + if so0.Nonce() != so1.Nonce() { + t.Fatalf("Nonce mismatch: have %v, want %v", so0.Nonce(), so1.Nonce()) + } + if so0.data.Root != so1.data.Root { + t.Errorf("Root mismatch: have %x, want %x", so0.data.Root[:], so1.data.Root[:]) + } + if !bytes.Equal(so0.CodeHash(), so1.CodeHash()) { + t.Fatalf("CodeHash mismatch: have %v, want %v", so0.CodeHash(), so1.CodeHash()) + } + if !bytes.Equal(so0.code, so1.code) { + t.Fatalf("Code mismatch: have %v, want %v", so0.code, so1.code) + } + + if len(so1.dirtyStorage) != len(so0.dirtyStorage) { + t.Errorf("Dirty storage size mismatch: have %d, want %d", len(so1.dirtyStorage), len(so0.dirtyStorage)) + } + for k, v := range so1.dirtyStorage { + if so0.dirtyStorage[k] != v { + t.Errorf("Dirty storage key %x mismatch: have %v, 
want %v", k, so0.dirtyStorage[k], v) + } + } + for k, v := range so0.dirtyStorage { + if so1.dirtyStorage[k] != v { + t.Errorf("Dirty storage key %x mismatch: have %v, want none.", k, v) + } + } + if len(so1.originStorage) != len(so0.originStorage) { + t.Errorf("Origin storage size mismatch: have %d, want %d", len(so1.originStorage), len(so0.originStorage)) + } + for k, v := range so1.originStorage { + if so0.originStorage[k] != v { + t.Errorf("Origin storage key %x mismatch: have %v, want %v", k, so0.originStorage[k], v) + } + } + for k, v := range so0.originStorage { + if so1.originStorage[k] != v { + t.Errorf("Origin storage key %x mismatch: have %v, want none.", k, v) + } + } +} diff --git a/trie_by_cid/state/statedb.go b/trie_by_cid/state/statedb.go index fc0c0b8..b4c4369 100644 --- a/trie_by_cid/state/statedb.go +++ b/trie_by_cid/state/statedb.go @@ -1,17 +1,45 @@ +// Copyright 2014 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +// Package state provides a caching layer atop the Ethereum state trie. 
package state import ( "errors" "fmt" "math/big" + "sort" "time" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/state/snapshot" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/metrics" + "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/rlp" + + "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie" ) +type revision struct { + id int + journalIndex int +} + type proofList [][]byte func (n *proofList) Put(key []byte, value []byte) error { @@ -28,46 +56,127 @@ func (n *proofList) Delete(key []byte) error { // nested states. It's the general query interface to retrieve: // * Contracts // * Accounts -// -// This implementation is read-only and performs no journaling, prefetching, or metrics tracking. type StateDB struct { - db Database - trie Trie - hasher crypto.KeccakState + db Database + prefetcher *triePrefetcher + trie Trie + hasher crypto.KeccakState + + // originalRoot is the pre-state root, before any changes were made. + // It will be updated when the Commit is called. + originalRoot common.Hash + + snaps *snapshot.Tree + snap snapshot.Snapshot + snapAccounts map[common.Hash][]byte + snapStorage map[common.Hash]map[common.Hash][]byte // This map holds 'live' objects, which will get modified while processing a state transition. - stateObjects map[common.Address]*stateObject + stateObjects map[common.Address]*stateObject + stateObjectsPending map[common.Address]struct{} // State objects finalized but not yet written to the trie + stateObjectsDirty map[common.Address]struct{} // State objects modified in the current execution + stateObjectsDestruct map[common.Address]struct{} // State objects destructed in the block // DB error. // State objects are used by the consensus core and VM which are // unable to deal with database-level errors. 
Any error that occurs - // during a database read is memoized here and will eventually be returned - // by StateDB.Commit. + // during a database read is memoized here and will eventually be + // returned by StateDB.Commit. Notably, this error is also shared + // by all cached state objects in case the database failure occurs + // when accessing state of accounts. dbErr error + // The refund counter, also used by state transitioning. + refund uint64 + + thash common.Hash + txIndex int + logs map[common.Hash][]*types.Log + logSize uint + preimages map[common.Hash][]byte + // Per-transaction access list + accessList *accessList + + // Transient storage + transientStorage transientStorage + + // Journal of state modifications. This is the backbone of + // Snapshot and RevertToSnapshot. + journal *journal + validRevisions []revision + nextRevisionId int + // Measurements gathered during execution for debugging purposes - AccountReads time.Duration - StorageReads time.Duration + AccountReads time.Duration + AccountHashes time.Duration + AccountUpdates time.Duration + StorageReads time.Duration + StorageHashes time.Duration + StorageUpdates time.Duration + SnapshotAccountReads time.Duration + SnapshotStorageReads time.Duration + + AccountUpdated int + StorageUpdated int + AccountDeleted int + StorageDeleted int } // New creates a new state from a given trie. 
-func New(root common.Hash, db Database) (*StateDB, error) { +func New(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, error) { tr, err := db.OpenTrie(root) if err != nil { return nil, err } sdb := &StateDB{ - db: db, - trie: tr, - stateObjects: make(map[common.Address]*stateObject), - preimages: make(map[common.Hash][]byte), - hasher: crypto.NewKeccakState(), + db: db, + trie: tr, + originalRoot: root, + snaps: snaps, + stateObjects: make(map[common.Address]*stateObject), + stateObjectsPending: make(map[common.Address]struct{}), + stateObjectsDirty: make(map[common.Address]struct{}), + stateObjectsDestruct: make(map[common.Address]struct{}), + logs: make(map[common.Hash][]*types.Log), + preimages: make(map[common.Hash][]byte), + journal: newJournal(), + accessList: newAccessList(), + transientStorage: newTransientStorage(), + hasher: crypto.NewKeccakState(), + } + if sdb.snaps != nil { + if sdb.snap = sdb.snaps.Snapshot(root); sdb.snap != nil { + sdb.snapAccounts = make(map[common.Hash][]byte) + sdb.snapStorage = make(map[common.Hash]map[common.Hash][]byte) + } } return sdb, nil } +// StartPrefetcher initializes a new trie prefetcher to pull in nodes from the +// state trie concurrently while the state is mutated so that when we reach the +// commit phase, most of the needed data is already hot. +func (s *StateDB) StartPrefetcher(namespace string) { + if s.prefetcher != nil { + s.prefetcher.close() + s.prefetcher = nil + } + if s.snap != nil { + s.prefetcher = newTriePrefetcher(s.db, s.originalRoot, namespace) + } +} + +// StopPrefetcher terminates a running prefetcher and reports any leftover stats +// from the gathered metrics. +func (s *StateDB) StopPrefetcher() { + if s.prefetcher != nil { + s.prefetcher.close() + s.prefetcher = nil + } +} + // setError remembers the first non-nil error it is called with. 
func (s *StateDB) setError(err error) { if s.dbErr == nil { @@ -75,17 +184,44 @@ func (s *StateDB) setError(err error) { } } +// Error returns the memorized database failure occurred earlier. func (s *StateDB) Error() error { return s.dbErr } func (s *StateDB) AddLog(log *types.Log) { - panic("unsupported") + s.journal.append(addLogChange{txhash: s.thash}) + + log.TxHash = s.thash + log.TxIndex = uint(s.txIndex) + log.Index = s.logSize + s.logs[s.thash] = append(s.logs[s.thash], log) + s.logSize++ +} + +// GetLogs returns the logs matching the specified transaction hash, and annotates +// them with the given blockNumber and blockHash. +func (s *StateDB) GetLogs(hash common.Hash, blockNumber uint64, blockHash common.Hash) []*types.Log { + logs := s.logs[hash] + for _, l := range logs { + l.BlockNumber = blockNumber + l.BlockHash = blockHash + } + return logs +} + +func (s *StateDB) Logs() []*types.Log { + var logs []*types.Log + for _, lgs := range s.logs { + logs = append(logs, lgs...) + } + return logs } // AddPreimage records a SHA3 preimage seen by the VM. func (s *StateDB) AddPreimage(hash common.Hash, preimage []byte) { if _, ok := s.preimages[hash]; !ok { + s.journal.append(addPreimageChange{hash: hash}) pi := make([]byte, len(preimage)) copy(pi, preimage) s.preimages[hash] = pi @@ -99,13 +235,18 @@ func (s *StateDB) Preimages() map[common.Hash][]byte { // AddRefund adds gas to the refund counter func (s *StateDB) AddRefund(gas uint64) { - panic("unsupported") + s.journal.append(refundChange{prev: s.refund}) + s.refund += gas } // SubRefund removes gas from the refund counter. // This method will panic if the refund counter goes below zero func (s *StateDB) SubRefund(gas uint64) { - panic("unsupported") + s.journal.append(refundChange{prev: s.refund}) + if gas > s.refund { + panic(fmt.Sprintf("Refund counter below zero (gas: %d > refund: %d)", gas, s.refund)) + } + s.refund -= gas } // Exist reports whether the given account address exists in the state. 
@@ -139,6 +280,11 @@ func (s *StateDB) GetNonce(addr common.Address) uint64 { return 0 } +// TxIndex returns the current transaction index set by Prepare. +func (s *StateDB) TxIndex() int { + return s.txIndex +} + func (s *StateDB) GetCode(addr common.Address) []byte { stateObject := s.getStateObject(addr) if stateObject != nil { @@ -186,18 +332,28 @@ func (s *StateDB) GetProofByHash(addrHash common.Hash) ([][]byte, error) { // GetStorageProof returns the Merkle proof for given storage slot. func (s *StateDB) GetStorageProof(a common.Address, key common.Hash) ([][]byte, error) { - var proof proofList - trie := s.StorageTrie(a) - if trie == nil { - return proof, errors.New("storage trie for requested address does not exist") + trie, err := s.StorageTrie(a) + if err != nil { + return nil, err } - err := trie.Prove(crypto.Keccak256(key.Bytes()), 0, &proof) - return proof, err + if trie == nil { + return nil, errors.New("storage trie for requested address does not exist") + } + var proof proofList + err = trie.Prove(crypto.Keccak256(key.Bytes()), 0, &proof) + if err != nil { + return nil, err + } + return proof, nil } // GetCommittedState retrieves a value from the given account's committed storage trie. func (s *StateDB) GetCommittedState(addr common.Address, hash common.Hash) common.Hash { - return s.GetState(addr, hash) + stateObject := s.getStateObject(addr) + if stateObject != nil { + return stateObject.GetCommittedState(s.db, hash) + } + return common.Hash{} } // Database retrieves the low level database supporting the lower level trie ops. @@ -205,17 +361,26 @@ func (s *StateDB) Database() Database { return s.db } -// StorageTrie returns the storage trie of an account. -// The return value is a copy and is nil for non-existent accounts. -func (s *StateDB) StorageTrie(addr common.Address) Trie { +// StorageTrie returns the storage trie of an account. The return value is a copy +// and is nil for non-existent accounts. 
An error will be returned if storage trie +// is existent but can't be loaded correctly. +func (s *StateDB) StorageTrie(addr common.Address) (Trie, error) { stateObject := s.getStateObject(addr) if stateObject == nil { - return nil + return nil, nil } - return stateObject.getTrie(s.db) + cpy := stateObject.deepCopy(s) + if _, err := cpy.updateTrie(s.db); err != nil { + return nil, err + } + return cpy.getTrie(s.db) } func (s *StateDB) HasSuicided(addr common.Address) bool { + stateObject := s.getStateObject(addr) + if stateObject != nil { + return stateObject.suicided + } return false } @@ -225,34 +390,61 @@ func (s *StateDB) HasSuicided(addr common.Address) bool { // AddBalance adds amount to the account associated with addr. func (s *StateDB) AddBalance(addr common.Address, amount *big.Int) { - panic("unsupported") + stateObject := s.GetOrNewStateObject(addr) + if stateObject != nil { + stateObject.AddBalance(amount) + } } // SubBalance subtracts amount from the account associated with addr. 
func (s *StateDB) SubBalance(addr common.Address, amount *big.Int) { - panic("unsupported") + stateObject := s.GetOrNewStateObject(addr) + if stateObject != nil { + stateObject.SubBalance(amount) + } } func (s *StateDB) SetBalance(addr common.Address, amount *big.Int) { - panic("unsupported") + stateObject := s.GetOrNewStateObject(addr) + if stateObject != nil { + stateObject.SetBalance(amount) + } } func (s *StateDB) SetNonce(addr common.Address, nonce uint64) { - panic("unsupported") + stateObject := s.GetOrNewStateObject(addr) + if stateObject != nil { + stateObject.SetNonce(nonce) + } } func (s *StateDB) SetCode(addr common.Address, code []byte) { - panic("unsupported") + stateObject := s.GetOrNewStateObject(addr) + if stateObject != nil { + stateObject.SetCode(crypto.Keccak256Hash(code), code) + } } func (s *StateDB) SetState(addr common.Address, key, value common.Hash) { - panic("unsupported") + stateObject := s.GetOrNewStateObject(addr) + if stateObject != nil { + stateObject.SetState(s.db, key, value) + } } // SetStorage replaces the entire storage for the specified account with given // storage. This function should only be used for debugging. func (s *StateDB) SetStorage(addr common.Address, storage map[common.Hash]common.Hash) { - panic("unsupported") + // SetStorage needs to wipe existing storage. We achieve this by pretending + // that the account self-destructed earlier in this block, by flagging + // it in stateObjectsDestruct. The effect of doing so is that storage lookups + // will not hit disk, since it is assumed that the disk-data is belonging + // to a previous incarnation of the object. + s.stateObjectsDestruct[addr] = struct{}{} + stateObject := s.GetOrNewStateObject(addr) + for k, v := range storage { + stateObject.SetState(s.db, k, v) + } } // Suicide marks the given account as suicided. 
@@ -261,36 +453,147 @@ func (s *StateDB) SetStorage(addr common.Address, storage map[common.Hash]common // The account's state object is still available until the state is committed, // getStateObject will return a non-nil account after Suicide. func (s *StateDB) Suicide(addr common.Address) bool { - panic("unsupported") - return false + stateObject := s.getStateObject(addr) + if stateObject == nil { + return false + } + s.journal.append(suicideChange{ + account: &addr, + prev: stateObject.suicided, + prevbalance: new(big.Int).Set(stateObject.Balance()), + }) + stateObject.markSuicided() + stateObject.data.Balance = new(big.Int) + + return true +} + +// SetTransientState sets transient storage for a given account. It +// adds the change to the journal so that it can be rolled back +// to its previous value if there is a revert. +func (s *StateDB) SetTransientState(addr common.Address, key, value common.Hash) { + prev := s.GetTransientState(addr, key) + if prev == value { + return + } + s.journal.append(transientStorageChange{ + account: &addr, + key: key, + prevalue: prev, + }) + s.setTransientState(addr, key, value) +} + +// setTransientState is a lower level setter for transient storage. It +// is called during a revert to prevent modifications to the journal. +func (s *StateDB) setTransientState(addr common.Address, key, value common.Hash) { + s.transientStorage.Set(addr, key, value) +} + +// GetTransientState gets transient storage for a given account. +func (s *StateDB) GetTransientState(addr common.Address, key common.Hash) common.Hash { + return s.transientStorage.Get(addr, key) } // // Setting, updating & deleting state object methods. // +// updateStateObject writes the given object to the trie. 
+func (s *StateDB) updateStateObject(obj *stateObject) { + // Track the amount of time wasted on updating the account from the trie + if metrics.EnabledExpensive { + defer func(start time.Time) { s.AccountUpdates += time.Since(start) }(time.Now()) + } + // Encode the account and update the account trie + addr := obj.Address() + if err := s.trie.TryUpdateAccount(addr, &obj.data); err != nil { + s.setError(fmt.Errorf("updateStateObject (%x) error: %v", addr[:], err)) + } + + // If state snapshotting is active, cache the data til commit. Note, this + // update mechanism is not symmetric to the deletion, because whereas it is + // enough to track account updates at commit time, deletions need tracking + // at transaction boundary level to ensure we capture state clearing. + if s.snap != nil { + s.snapAccounts[obj.addrHash] = snapshot.SlimAccountRLP(obj.data.Nonce, obj.data.Balance, obj.data.Root, obj.data.CodeHash) + } +} + +// deleteStateObject removes the given object from the state trie. +func (s *StateDB) deleteStateObject(obj *stateObject) { + // Track the amount of time wasted on deleting the account from the trie + if metrics.EnabledExpensive { + defer func(start time.Time) { s.AccountUpdates += time.Since(start) }(time.Now()) + } + // Delete the account from the trie + addr := obj.Address() + if err := s.trie.TryDeleteAccount(addr); err != nil { + s.setError(fmt.Errorf("deleteStateObject (%x) error: %v", addr[:], err)) + } +} + // getStateObject retrieves a state object given by the address, returning nil if -// the object is not found or was deleted in this execution context. +// the object is not found or was deleted in this execution context. If you need +// to differentiate between non-existent/just-deleted, use getDeletedStateObject. 
func (s *StateDB) getStateObject(addr common.Address) *stateObject { + if obj := s.getDeletedStateObject(addr); obj != nil && !obj.deleted { + return obj + } + return nil +} + +// getDeletedStateObject is similar to getStateObject, but instead of returning +// nil for a deleted state object, it returns the actual object with the deleted +// flag set. This is needed by the state journal to revert to the correct s- +// destructed object instead of wiping all knowledge about the state object. +func (s *StateDB) getDeletedStateObject(addr common.Address) *stateObject { // Prefer live objects if any is available if obj := s.stateObjects[addr]; obj != nil { return obj } - // If no live objects are available, load from the database - start := time.Now() - var err error - data, err := s.trie.TryGetAccount(addr.Bytes()) - if metrics.EnabledExpensive { - s.AccountReads += time.Since(start) - } - if err != nil { - s.setError(fmt.Errorf("getStateObject (%x) error: %w", addr.Bytes(), err)) - return nil + // If no live objects are available, attempt to use snapshots + var data *types.StateAccount + if s.snap != nil { + start := time.Now() + acc, err := s.snap.Account(crypto.HashData(s.hasher, addr.Bytes())) + if metrics.EnabledExpensive { + s.SnapshotAccountReads += time.Since(start) + } + if err == nil { + if acc == nil { + return nil + } + data = &types.StateAccount{ + Nonce: acc.Nonce, + Balance: acc.Balance, + CodeHash: acc.CodeHash, + Root: common.BytesToHash(acc.Root), + } + if len(data.CodeHash) == 0 { + data.CodeHash = types.EmptyCodeHash.Bytes() + } + if data.Root == (common.Hash{}) { + data.Root = types.EmptyRootHash + } + } } + // If snapshot unavailable or reading from it failed, load from the database if data == nil { - return nil + start := time.Now() + var err error + data, err = s.trie.TryGetAccount(addr) + if metrics.EnabledExpensive { + s.AccountReads += time.Since(start) + } + if err != nil { + s.setError(fmt.Errorf("getDeleteStateObject (%x) error: %w", 
addr.Bytes(), err)) + return nil + } + if data == nil { + return nil + } } - // Insert into the live set obj := newObject(s, addr, *data) s.setStateObject(obj) @@ -301,69 +604,439 @@ func (s *StateDB) setStateObject(object *stateObject) { s.stateObjects[object.Address()] = object } +// GetOrNewStateObject retrieves a state object or create a new state object if nil. +func (s *StateDB) GetOrNewStateObject(addr common.Address) *stateObject { + stateObject := s.getStateObject(addr) + if stateObject == nil { + stateObject, _ = s.createObject(addr) + } + return stateObject +} + +// createObject creates a new state object. If there is an existing account with +// the given address, it is overwritten and returned as the second return value. +func (s *StateDB) createObject(addr common.Address) (newobj, prev *stateObject) { + prev = s.getDeletedStateObject(addr) // Note, prev might have been deleted, we need that! + + var prevdestruct bool + if prev != nil { + _, prevdestruct = s.stateObjectsDestruct[prev.address] + if !prevdestruct { + s.stateObjectsDestruct[prev.address] = struct{}{} + } + } + newobj = newObject(s, addr, types.StateAccount{}) + if prev == nil { + s.journal.append(createObjectChange{account: &addr}) + } else { + s.journal.append(resetObjectChange{prev: prev, prevdestruct: prevdestruct}) + } + s.setStateObject(newobj) + if prev != nil && !prev.deleted { + return newobj, prev + } + return newobj, nil +} + // CreateAccount explicitly creates a state object. If a state object with the address // already exists the balance is carried over to the new account. // // CreateAccount is called during the EVM CREATE operation. The situation might arise that // a contract does the following: // -// 1. sends funds to sha(account ++ (nonce + 1)) -// 2. tx_create(sha(account ++ nonce)) (note that this gets the address of 1) +// 1. sends funds to sha(account ++ (nonce + 1)) +// 2. 
tx_create(sha(account ++ nonce)) (note that this gets the address of 1) // // Carrying over the balance ensures that Ether doesn't disappear. func (s *StateDB) CreateAccount(addr common.Address) { - panic("unsupported") + newObj, prev := s.createObject(addr) + if prev != nil { + newObj.setBalance(prev.data.Balance) + } } func (db *StateDB) ForEachStorage(addr common.Address, cb func(key, value common.Hash) bool) error { + so := db.getStateObject(addr) + if so == nil { + return nil + } + tr, err := so.getTrie(db.db) + if err != nil { + return err + } + it := trie.NewIterator(tr.NodeIterator(nil)) + + for it.Next() { + key := common.BytesToHash(db.trie.GetKey(it.Key)) + if value, dirty := so.dirtyStorage[key]; dirty { + if !cb(key, value) { + return nil + } + continue + } + + if len(it.Value) > 0 { + _, content, _, err := rlp.Split(it.Value) + if err != nil { + return err + } + if !cb(key, common.BytesToHash(content)) { + return nil + } + } + } return nil } +// Copy creates a deep, independent copy of the state. +// Snapshots of the copied state cannot be applied to the copy. 
+func (s *StateDB) Copy() *StateDB { + // Copy all the basic fields, initialize the memory ones + state := &StateDB{ + db: s.db, + trie: s.db.CopyTrie(s.trie), + originalRoot: s.originalRoot, + stateObjects: make(map[common.Address]*stateObject, len(s.journal.dirties)), + stateObjectsPending: make(map[common.Address]struct{}, len(s.stateObjectsPending)), + stateObjectsDirty: make(map[common.Address]struct{}, len(s.journal.dirties)), + stateObjectsDestruct: make(map[common.Address]struct{}, len(s.stateObjectsDestruct)), + refund: s.refund, + logs: make(map[common.Hash][]*types.Log, len(s.logs)), + logSize: s.logSize, + preimages: make(map[common.Hash][]byte, len(s.preimages)), + journal: newJournal(), + hasher: crypto.NewKeccakState(), + } + // Copy the dirty states, logs, and preimages + for addr := range s.journal.dirties { + // As documented [here](https://github.com/ethereum/go-ethereum/pull/16485#issuecomment-380438527), + // and in the Finalise-method, there is a case where an object is in the journal but not + // in the stateObjects: OOG after touch on ripeMD prior to Byzantium. Thus, we need to check for + // nil + if object, exist := s.stateObjects[addr]; exist { + // Even though the original object is dirty, we are not copying the journal, + // so we need to make sure that any side-effect the journal would have caused + // during a commit (or similar op) is already applied to the copy. + state.stateObjects[addr] = object.deepCopy(state) + + state.stateObjectsDirty[addr] = struct{}{} // Mark the copy dirty to force internal (code/state) commits + state.stateObjectsPending[addr] = struct{}{} // Mark the copy pending to force external (account) commits + } + } + // Above, we don't copy the actual journal. This means that if the copy + // is copied, the loop above will be a no-op, since the copy's journal + // is empty. Thus, here we iterate over stateObjects, to enable copies + // of copies. 
+ for addr := range s.stateObjectsPending { + if _, exist := state.stateObjects[addr]; !exist { + state.stateObjects[addr] = s.stateObjects[addr].deepCopy(state) + } + state.stateObjectsPending[addr] = struct{}{} + } + for addr := range s.stateObjectsDirty { + if _, exist := state.stateObjects[addr]; !exist { + state.stateObjects[addr] = s.stateObjects[addr].deepCopy(state) + } + state.stateObjectsDirty[addr] = struct{}{} + } + // Deep copy the destruction flag. + for addr := range s.stateObjectsDestruct { + state.stateObjectsDestruct[addr] = struct{}{} + } + for hash, logs := range s.logs { + cpy := make([]*types.Log, len(logs)) + for i, l := range logs { + cpy[i] = new(types.Log) + *cpy[i] = *l + } + state.logs[hash] = cpy + } + for hash, preimage := range s.preimages { + state.preimages[hash] = preimage + } + // Do we need to copy the access list and transient storage? + // In practice: No. At the start of a transaction, these two lists are empty. + // In practice, we only ever copy state _between_ transactions/blocks, never + // in the middle of a transaction. However, it doesn't cost us much to copy + // empty lists, so we do it anyway to not blow up if we ever decide copy them + // in the middle of a transaction. + state.accessList = s.accessList.Copy() + state.transientStorage = s.transientStorage.Copy() + + // If there's a prefetcher running, make an inactive copy of it that can + // only access data but does not actively preload (since the user will not + // know that they need to explicitly terminate an active copy). + if s.prefetcher != nil { + state.prefetcher = s.prefetcher.copy() + } + if s.snaps != nil { + // In order for the miner to be able to use and make additions + // to the snapshot tree, we need to copy that as well. 
+ // Otherwise, any block mined by ourselves will cause gaps in the tree, + // and force the miner to operate trie-backed only + state.snaps = s.snaps + state.snap = s.snap + + // deep copy needed + state.snapAccounts = make(map[common.Hash][]byte) + for k, v := range s.snapAccounts { + state.snapAccounts[k] = v + } + state.snapStorage = make(map[common.Hash]map[common.Hash][]byte) + for k, v := range s.snapStorage { + temp := make(map[common.Hash][]byte) + for kk, vv := range v { + temp[kk] = vv + } + state.snapStorage[k] = temp + } + } + return state +} + // Snapshot returns an identifier for the current revision of the state. func (s *StateDB) Snapshot() int { - return 0 + id := s.nextRevisionId + s.nextRevisionId++ + s.validRevisions = append(s.validRevisions, revision{id, s.journal.length()}) + return id } // RevertToSnapshot reverts all state changes made since the given revision. func (s *StateDB) RevertToSnapshot(revid int) { - panic("unsupported") + // Find the snapshot in the stack of valid snapshots. + idx := sort.Search(len(s.validRevisions), func(i int) bool { + return s.validRevisions[i].id >= revid + }) + if idx == len(s.validRevisions) || s.validRevisions[idx].id != revid { + panic(fmt.Errorf("revision id %v cannot be reverted", revid)) + } + snapshot := s.validRevisions[idx].journalIndex + + // Replay the journal to undo changes and remove invalidated snapshots + s.journal.revert(s, snapshot) + s.validRevisions = s.validRevisions[:idx] } // GetRefund returns the current value of the refund counter. func (s *StateDB) GetRefund() uint64 { - panic("unsupported") - return 0 + return s.refund } -// PrepareAccessList handles the preparatory steps for executing a state transition with -// regards to both EIP-2929 and EIP-2930: +// Finalise finalises the state by removing the destructed objects and clears +// the journal as well as the refunds. Finalise, however, will not push any updates +// into the tries just yet. 
Only IntermediateRoot or Commit will do that. +func (s *StateDB) Finalise(deleteEmptyObjects bool) { + addressesToPrefetch := make([][]byte, 0, len(s.journal.dirties)) + for addr := range s.journal.dirties { + obj, exist := s.stateObjects[addr] + if !exist { + // ripeMD is 'touched' at block 1714175, in tx 0x1237f737031e40bcde4a8b7e717b2d15e3ecadfe49bb1bbc71ee9deb09c6fcf2 + // That tx goes out of gas, and although the notion of 'touched' does not exist there, the + // touch-event will still be recorded in the journal. Since ripeMD is a special snowflake, + // it will persist in the journal even though the journal is reverted. In this special circumstance, + // it may exist in `s.journal.dirties` but not in `s.stateObjects`. + // Thus, we can safely ignore it here + continue + } + if obj.suicided || (deleteEmptyObjects && obj.empty()) { + obj.deleted = true + + // We need to maintain account deletions explicitly (will remain + // set indefinitely). + s.stateObjectsDestruct[obj.address] = struct{}{} + + // If state snapshotting is active, also mark the destruction there. + // Note, we can't do this only at the end of a block because multiple + // transactions within the same block might self destruct and then + // resurrect an account; but the snapshotter needs both events. + if s.snap != nil { + delete(s.snapAccounts, obj.addrHash) // Clear out any previously updated account data (may be recreated via a resurrect) + delete(s.snapStorage, obj.addrHash) // Clear out any previously updated storage data (may be recreated via a resurrect) + } + } else { + obj.finalise(true) // Prefetch slots in the background + } + s.stateObjectsPending[addr] = struct{}{} + s.stateObjectsDirty[addr] = struct{}{} + + // At this point, also ship the address off to the precacher. 
The precacher + // will start loading tries, and when the change is eventually committed, + // the commit-phase will be a lot faster + addressesToPrefetch = append(addressesToPrefetch, common.CopyBytes(addr[:])) // Copy needed for closure + } + if s.prefetcher != nil && len(addressesToPrefetch) > 0 { + s.prefetcher.prefetch(common.Hash{}, s.originalRoot, addressesToPrefetch) + } + // Invalidate journal because reverting across transactions is not allowed. + s.clearJournalAndRefund() +} + +// IntermediateRoot computes the current root hash of the state trie. +// It is called in between transactions to get the root hash that +// goes into transaction receipts. +func (s *StateDB) IntermediateRoot(deleteEmptyObjects bool) common.Hash { + // Finalise all the dirty storage states and write them into the tries + s.Finalise(deleteEmptyObjects) + + // If there was a trie prefetcher operating, it gets aborted and irrevocably + // modified after we start retrieving tries. Remove it from the statedb after + // this round of use. + // + // This is weird pre-byzantium since the first tx runs with a prefetcher and + // the remainder without, but pre-byzantium even the initial prefetcher is + // useless, so no sleep lost. + prefetcher := s.prefetcher + if s.prefetcher != nil { + defer func() { + s.prefetcher.close() + s.prefetcher = nil + }() + } + // Although naively it makes sense to retrieve the account trie and then do + // the contract storage and account updates sequentially, that short circuits + // the account prefetcher. Instead, let's process all the storage updates + // first, giving the account prefetches just a few more milliseconds of time + // to pull useful data from disk. + for addr := range s.stateObjectsPending { + if obj := s.stateObjects[addr]; !obj.deleted { + obj.updateRoot(s.db) + } + } + // Now we're about to start to write changes to the trie. The trie is so far + // _untouched_. 
We can check with the prefetcher, if it can give us a trie + // which has the same root, but also has some content loaded into it. + if prefetcher != nil { + if trie := prefetcher.trie(common.Hash{}, s.originalRoot); trie != nil { + s.trie = trie + } + } + usedAddrs := make([][]byte, 0, len(s.stateObjectsPending)) + for addr := range s.stateObjectsPending { + if obj := s.stateObjects[addr]; obj.deleted { + s.deleteStateObject(obj) + s.AccountDeleted += 1 + } else { + s.updateStateObject(obj) + s.AccountUpdated += 1 + } + usedAddrs = append(usedAddrs, common.CopyBytes(addr[:])) // Copy needed for closure + } + if prefetcher != nil { + prefetcher.used(common.Hash{}, s.originalRoot, usedAddrs) + } + if len(s.stateObjectsPending) > 0 { + s.stateObjectsPending = make(map[common.Address]struct{}) + } + // Track the amount of time wasted on hashing the account trie + if metrics.EnabledExpensive { + defer func(start time.Time) { s.AccountHashes += time.Since(start) }(time.Now()) + } + return s.trie.Hash() +} + +// SetTxContext sets the current transaction hash and index which are +// used when the EVM emits new state logs. It should be invoked before +// transaction execution. +func (s *StateDB) SetTxContext(thash common.Hash, ti int) { + s.thash = thash + s.txIndex = ti +} + +func (s *StateDB) clearJournalAndRefund() { + if len(s.journal.entries) > 0 { + s.journal = newJournal() + s.refund = 0 + } + s.validRevisions = s.validRevisions[:0] // Snapshots can be created without journal entries +} + +// Prepare handles the preparatory steps for executing a state transition with. +// This method must be invoked before state transition. // +// Berlin fork: // - Add sender to access list (2929) // - Add destination to access list (2929) // - Add precompiles to access list (2929) // - Add the contents of the optional tx access list (2930) // -// This method should only be called if Berlin/2929+2930 is applicable at the current number. 
-func (s *StateDB) PrepareAccessList(sender common.Address, dst *common.Address, precompiles []common.Address, list types.AccessList) { - panic("unsupported") +// Potential EIPs: +// - Reset access list (Berlin) +// - Add coinbase to access list (EIP-3651) +// - Reset transient storage (EIP-1153) +func (s *StateDB) Prepare(rules params.Rules, sender, coinbase common.Address, dst *common.Address, precompiles []common.Address, list types.AccessList) { + if rules.IsBerlin { + // Clear out any leftover from previous executions + al := newAccessList() + s.accessList = al + + al.AddAddress(sender) + if dst != nil { + al.AddAddress(*dst) + // If it's a create-tx, the destination will be added inside evm.create + } + for _, addr := range precompiles { + al.AddAddress(addr) + } + for _, el := range list { + al.AddAddress(el.Address) + for _, key := range el.StorageKeys { + al.AddSlot(el.Address, key) + } + } + if rules.IsShanghai { // EIP-3651: warm coinbase + al.AddAddress(coinbase) + } + } + // Reset transient storage at the beginning of transaction execution + s.transientStorage = newTransientStorage() } // AddAddressToAccessList adds the given address to the access list func (s *StateDB) AddAddressToAccessList(addr common.Address) { - panic("unsupported") + if s.accessList.AddAddress(addr) { + s.journal.append(accessListAddAccountChange{&addr}) + } } // AddSlotToAccessList adds the given (address, slot)-tuple to the access list func (s *StateDB) AddSlotToAccessList(addr common.Address, slot common.Hash) { - panic("unsupported") + addrMod, slotMod := s.accessList.AddSlot(addr, slot) + if addrMod { + // In practice, this should not happen, since there is no way to enter the + // scope of 'address' without having the 'address' become already added + // to the access list (via call-variant, create, etc). 
+ // Better safe than sorry, though + s.journal.append(accessListAddAccountChange{&addr}) + } + if slotMod { + s.journal.append(accessListAddSlotChange{ + address: &addr, + slot: &slot, + }) + } } // AddressInAccessList returns true if the given address is in the access list. func (s *StateDB) AddressInAccessList(addr common.Address) bool { - return false + return s.accessList.ContainsAddress(addr) } // SlotInAccessList returns true if the given (address, slot)-tuple is in the access list. func (s *StateDB) SlotInAccessList(addr common.Address, slot common.Hash) (addressPresent bool, slotPresent bool) { - return + return s.accessList.Contains(addr, slot) +} + +// convertAccountSet converts a provided account set from address keyed to hash keyed. +func (s *StateDB) convertAccountSet(set map[common.Address]struct{}) map[common.Hash]struct{} { + ret := make(map[common.Hash]struct{}) + for addr := range set { + obj, exist := s.stateObjects[addr] + if !exist { + ret[crypto.Keccak256Hash(addr[:])] = struct{}{} + } else { + ret[obj.addrHash] = struct{}{} + } + } + return ret } diff --git a/trie_by_cid/state/statedb_test.go b/trie_by_cid/state/statedb_test.go new file mode 100644 index 0000000..60cd66a --- /dev/null +++ b/trie_by_cid/state/statedb_test.go @@ -0,0 +1,594 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. 
+// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package state + +import ( + "bytes" + "encoding/binary" + "fmt" + "math" + "math/big" + "math/rand" + "reflect" + "strings" + "sync" + "testing" + "testing/quick" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/types" +) + +// TestCopy tests that copying a StateDB object indeed makes the original and +// the copy independent of each other. This test is a regression test against +// https://github.com/ethereum/go-ethereum/pull/15549. +func TestCopy(t *testing.T) { + // Create a random state test to copy and modify "independently" + orig, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil) + + for i := byte(0); i < 255; i++ { + obj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i})) + obj.AddBalance(big.NewInt(int64(i))) + orig.updateStateObject(obj) + } + orig.Finalise(false) + + // Copy the state + copy := orig.Copy() + + // Copy the copy state + ccopy := copy.Copy() + + // modify all in memory + for i := byte(0); i < 255; i++ { + origObj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i})) + copyObj := copy.GetOrNewStateObject(common.BytesToAddress([]byte{i})) + ccopyObj := ccopy.GetOrNewStateObject(common.BytesToAddress([]byte{i})) + + origObj.AddBalance(big.NewInt(2 * int64(i))) + copyObj.AddBalance(big.NewInt(3 * int64(i))) + ccopyObj.AddBalance(big.NewInt(4 * int64(i))) + + orig.updateStateObject(origObj) + copy.updateStateObject(copyObj) + ccopy.updateStateObject(copyObj) + } + + // Finalise the changes on all concurrently + finalise := func(wg *sync.WaitGroup, db *StateDB) { + defer wg.Done() + db.Finalise(true) + } + + var wg sync.WaitGroup + wg.Add(3) + go finalise(&wg, orig) + go finalise(&wg, copy) + go finalise(&wg, ccopy) + wg.Wait() + + // Verify that the three states have been updated 
independently + for i := byte(0); i < 255; i++ { + origObj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i})) + copyObj := copy.GetOrNewStateObject(common.BytesToAddress([]byte{i})) + ccopyObj := ccopy.GetOrNewStateObject(common.BytesToAddress([]byte{i})) + + if want := big.NewInt(3 * int64(i)); origObj.Balance().Cmp(want) != 0 { + t.Errorf("orig obj %d: balance mismatch: have %v, want %v", i, origObj.Balance(), want) + } + if want := big.NewInt(4 * int64(i)); copyObj.Balance().Cmp(want) != 0 { + t.Errorf("copy obj %d: balance mismatch: have %v, want %v", i, copyObj.Balance(), want) + } + if want := big.NewInt(5 * int64(i)); ccopyObj.Balance().Cmp(want) != 0 { + t.Errorf("copy obj %d: balance mismatch: have %v, want %v", i, ccopyObj.Balance(), want) + } + } +} + +func TestSnapshotRandom(t *testing.T) { + config := &quick.Config{MaxCount: 1000} + err := quick.Check((*snapshotTest).run, config) + if cerr, ok := err.(*quick.CheckError); ok { + test := cerr.In[0].(*snapshotTest) + t.Errorf("%v:\n%s", test.err, test) + } else if err != nil { + t.Error(err) + } +} + +// A snapshotTest checks that reverting StateDB snapshots properly undoes all changes +// captured by the snapshot. Instances of this test with pseudorandom content are created +// by Generate. +// +// The test works as follows: +// +// A new state is created and all actions are applied to it. Several snapshots are taken +// in between actions. The test then reverts each snapshot. For each snapshot the actions +// leading up to it are replayed on a fresh, empty state. The behaviour of all public +// accessor methods on the reverted state must match the return value of the equivalent +// methods on the replayed state. 
+type snapshotTest struct { + addrs []common.Address // all account addresses + actions []testAction // modifications to the state + snapshots []int // actions indexes at which snapshot is taken + err error // failure details are reported through this field +} + +type testAction struct { + name string + fn func(testAction, *StateDB) + args []int64 + noAddr bool +} + +// newTestAction creates a random action that changes state. +func newTestAction(addr common.Address, r *rand.Rand) testAction { + actions := []testAction{ + { + name: "SetBalance", + fn: func(a testAction, s *StateDB) { + s.SetBalance(addr, big.NewInt(a.args[0])) + }, + args: make([]int64, 1), + }, + { + name: "AddBalance", + fn: func(a testAction, s *StateDB) { + s.AddBalance(addr, big.NewInt(a.args[0])) + }, + args: make([]int64, 1), + }, + { + name: "SetNonce", + fn: func(a testAction, s *StateDB) { + s.SetNonce(addr, uint64(a.args[0])) + }, + args: make([]int64, 1), + }, + { + name: "SetState", + fn: func(a testAction, s *StateDB) { + var key, val common.Hash + binary.BigEndian.PutUint16(key[:], uint16(a.args[0])) + binary.BigEndian.PutUint16(val[:], uint16(a.args[1])) + s.SetState(addr, key, val) + }, + args: make([]int64, 2), + }, + { + name: "SetCode", + fn: func(a testAction, s *StateDB) { + code := make([]byte, 16) + binary.BigEndian.PutUint64(code, uint64(a.args[0])) + binary.BigEndian.PutUint64(code[8:], uint64(a.args[1])) + s.SetCode(addr, code) + }, + args: make([]int64, 2), + }, + { + name: "CreateAccount", + fn: func(a testAction, s *StateDB) { + s.CreateAccount(addr) + }, + }, + { + name: "Suicide", + fn: func(a testAction, s *StateDB) { + s.Suicide(addr) + }, + }, + { + name: "AddRefund", + fn: func(a testAction, s *StateDB) { + s.AddRefund(uint64(a.args[0])) + }, + args: make([]int64, 1), + noAddr: true, + }, + { + name: "AddLog", + fn: func(a testAction, s *StateDB) { + data := make([]byte, 2) + binary.BigEndian.PutUint16(data, uint16(a.args[0])) + s.AddLog(&types.Log{Address: addr, 
Data: data}) + }, + args: make([]int64, 1), + }, + { + name: "AddPreimage", + fn: func(a testAction, s *StateDB) { + preimage := []byte{1} + hash := common.BytesToHash(preimage) + s.AddPreimage(hash, preimage) + }, + args: make([]int64, 1), + }, + { + name: "AddAddressToAccessList", + fn: func(a testAction, s *StateDB) { + s.AddAddressToAccessList(addr) + }, + }, + { + name: "AddSlotToAccessList", + fn: func(a testAction, s *StateDB) { + s.AddSlotToAccessList(addr, + common.Hash{byte(a.args[0])}) + }, + args: make([]int64, 1), + }, + { + name: "SetTransientState", + fn: func(a testAction, s *StateDB) { + var key, val common.Hash + binary.BigEndian.PutUint16(key[:], uint16(a.args[0])) + binary.BigEndian.PutUint16(val[:], uint16(a.args[1])) + s.SetTransientState(addr, key, val) + }, + args: make([]int64, 2), + }, + } + action := actions[r.Intn(len(actions))] + var nameargs []string + if !action.noAddr { + nameargs = append(nameargs, addr.Hex()) + } + for i := range action.args { + action.args[i] = rand.Int63n(100) + nameargs = append(nameargs, fmt.Sprint(action.args[i])) + } + action.name += strings.Join(nameargs, ", ") + return action +} + +// Generate returns a new snapshot test of the given size. All randomness is +// derived from r. +func (*snapshotTest) Generate(r *rand.Rand, size int) reflect.Value { + // Generate random actions. + addrs := make([]common.Address, 50) + for i := range addrs { + addrs[i][0] = byte(i) + } + actions := make([]testAction, size) + for i := range actions { + addr := addrs[r.Intn(len(addrs))] + actions[i] = newTestAction(addr, r) + } + // Generate snapshot indexes. + nsnapshots := int(math.Sqrt(float64(size))) + if size > 0 && nsnapshots == 0 { + nsnapshots = 1 + } + snapshots := make([]int, nsnapshots) + snaplen := len(actions) / nsnapshots + for i := range snapshots { + // Try to place the snapshots some number of actions apart from each other. 
+ snapshots[i] = (i * snaplen) + r.Intn(snaplen) + } + return reflect.ValueOf(&snapshotTest{addrs, actions, snapshots, nil}) +} + +func (test *snapshotTest) String() string { + out := new(bytes.Buffer) + sindex := 0 + for i, action := range test.actions { + if len(test.snapshots) > sindex && i == test.snapshots[sindex] { + fmt.Fprintf(out, "---- snapshot %d ----\n", sindex) + sindex++ + } + fmt.Fprintf(out, "%4d: %s\n", i, action.name) + } + return out.String() +} + +func (test *snapshotTest) run() bool { + // Run all actions and create snapshots. + var ( + state, _ = New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil) + snapshotRevs = make([]int, len(test.snapshots)) + sindex = 0 + ) + for i, action := range test.actions { + if len(test.snapshots) > sindex && i == test.snapshots[sindex] { + snapshotRevs[sindex] = state.Snapshot() + sindex++ + } + action.fn(action, state) + } + // Revert all snapshots in reverse order. Each revert must yield a state + // that is equivalent to fresh state with all actions up the snapshot applied. + for sindex--; sindex >= 0; sindex-- { + checkstate, _ := New(common.Hash{}, state.Database(), nil) + for _, action := range test.actions[:test.snapshots[sindex]] { + action.fn(action, checkstate) + } + state.RevertToSnapshot(snapshotRevs[sindex]) + if err := test.checkEqual(state, checkstate); err != nil { + test.err = fmt.Errorf("state mismatch after revert to snapshot %d\n%v", sindex, err) + return false + } + } + return true +} + +// checkEqual checks that methods of state and checkstate return the same values. +func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error { + for _, addr := range test.addrs { + var err error + checkeq := func(op string, a, b interface{}) bool { + if err == nil && !reflect.DeepEqual(a, b) { + err = fmt.Errorf("got %s(%s) == %v, want %v", op, addr.Hex(), a, b) + return false + } + return true + } + // Check basic accessor methods. 
+ checkeq("Exist", state.Exist(addr), checkstate.Exist(addr)) + checkeq("HasSuicided", state.HasSuicided(addr), checkstate.HasSuicided(addr)) + checkeq("GetBalance", state.GetBalance(addr), checkstate.GetBalance(addr)) + checkeq("GetNonce", state.GetNonce(addr), checkstate.GetNonce(addr)) + checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr)) + checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr)) + checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr)) + // Check storage. + if obj := state.getStateObject(addr); obj != nil { + state.ForEachStorage(addr, func(key, value common.Hash) bool { + return checkeq("GetState("+key.Hex()+")", checkstate.GetState(addr, key), value) + }) + checkstate.ForEachStorage(addr, func(key, value common.Hash) bool { + return checkeq("GetState("+key.Hex()+")", checkstate.GetState(addr, key), value) + }) + } + if err != nil { + return err + } + } + + if state.GetRefund() != checkstate.GetRefund() { + return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d", + state.GetRefund(), checkstate.GetRefund()) + } + if !reflect.DeepEqual(state.GetLogs(common.Hash{}, 0, common.Hash{}), checkstate.GetLogs(common.Hash{}, 0, common.Hash{})) { + return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v", + state.GetLogs(common.Hash{}, 0, common.Hash{}), checkstate.GetLogs(common.Hash{}, 0, common.Hash{})) + } + return nil +} + +// TestCopyOfCopy tests that modified objects are carried over to the copy, and the copy of the copy. 
+// See https://github.com/ethereum/go-ethereum/pull/15225#issuecomment-380191512 +func TestCopyOfCopy(t *testing.T) { + state, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil) + addr := common.HexToAddress("aaaa") + state.SetBalance(addr, big.NewInt(42)) + + if got := state.Copy().GetBalance(addr).Uint64(); got != 42 { + t.Fatalf("1st copy fail, expected 42, got %v", got) + } + if got := state.Copy().Copy().GetBalance(addr).Uint64(); got != 42 { + t.Fatalf("2nd copy fail, expected 42, got %v", got) + } +} + +func TestStateDBAccessList(t *testing.T) { + // Some helpers + addr := func(a string) common.Address { + return common.HexToAddress(a) + } + slot := func(a string) common.Hash { + return common.HexToHash(a) + } + + memDb := rawdb.NewMemoryDatabase() + db := NewDatabase(memDb) + state, _ := New(common.Hash{}, db, nil) + state.accessList = newAccessList() + + verifyAddrs := func(astrings ...string) { + t.Helper() + // convert to common.Address form + var addresses []common.Address + var addressMap = make(map[common.Address]struct{}) + for _, astring := range astrings { + address := addr(astring) + addresses = append(addresses, address) + addressMap[address] = struct{}{} + } + // Check that the given addresses are in the access list + for _, address := range addresses { + if !state.AddressInAccessList(address) { + t.Fatalf("expected %x to be in access list", address) + } + } + // Check that only the expected addresses are present in the access list + for address := range state.accessList.addresses { + if _, exist := addressMap[address]; !exist { + t.Fatalf("extra address %x in access list", address) + } + } + } + verifySlots := func(addrString string, slotStrings ...string) { + if !state.AddressInAccessList(addr(addrString)) { + t.Fatalf("scope missing address/slots %v", addrString) + } + var address = addr(addrString) + // convert to common.Hash form + var slots []common.Hash + var slotMap = make(map[common.Hash]struct{}) + for _, slotString 
:= range slotStrings { + s := slot(slotString) + slots = append(slots, s) + slotMap[s] = struct{}{} + } + // Check that the expected items are in the access list + for i, s := range slots { + if _, slotPresent := state.SlotInAccessList(address, s); !slotPresent { + t.Fatalf("input %d: scope missing slot %v (address %v)", i, s, addrString) + } + } + // Check that no extra elements are in the access list + index := state.accessList.addresses[address] + if index >= 0 { + stateSlots := state.accessList.slots[index] + for s := range stateSlots { + if _, slotPresent := slotMap[s]; !slotPresent { + t.Fatalf("scope has extra slot %v (address %v)", s, addrString) + } + } + } + } + + state.AddAddressToAccessList(addr("aa")) // 1 + state.AddSlotToAccessList(addr("bb"), slot("01")) // 2,3 + state.AddSlotToAccessList(addr("bb"), slot("02")) // 4 + verifyAddrs("aa", "bb") + verifySlots("bb", "01", "02") + + // Make a copy + stateCopy1 := state.Copy() + if exp, got := 4, state.journal.length(); exp != got { + t.Fatalf("journal length mismatch: have %d, want %d", got, exp) + } + + // same again, should cause no journal entries + state.AddSlotToAccessList(addr("bb"), slot("01")) + state.AddSlotToAccessList(addr("bb"), slot("02")) + state.AddAddressToAccessList(addr("aa")) + if exp, got := 4, state.journal.length(); exp != got { + t.Fatalf("journal length mismatch: have %d, want %d", got, exp) + } + // some new ones + state.AddSlotToAccessList(addr("bb"), slot("03")) // 5 + state.AddSlotToAccessList(addr("aa"), slot("01")) // 6 + state.AddSlotToAccessList(addr("cc"), slot("01")) // 7,8 + state.AddAddressToAccessList(addr("cc")) + if exp, got := 8, state.journal.length(); exp != got { + t.Fatalf("journal length mismatch: have %d, want %d", got, exp) + } + + verifyAddrs("aa", "bb", "cc") + verifySlots("aa", "01") + verifySlots("bb", "01", "02", "03") + verifySlots("cc", "01") + + // now start rolling back changes + state.journal.revert(state, 7) + if _, ok := 
state.SlotInAccessList(addr("cc"), slot("01")); ok { + t.Fatalf("slot present, expected missing") + } + verifyAddrs("aa", "bb", "cc") + verifySlots("aa", "01") + verifySlots("bb", "01", "02", "03") + + state.journal.revert(state, 6) + if state.AddressInAccessList(addr("cc")) { + t.Fatalf("addr present, expected missing") + } + verifyAddrs("aa", "bb") + verifySlots("aa", "01") + verifySlots("bb", "01", "02", "03") + + state.journal.revert(state, 5) + if _, ok := state.SlotInAccessList(addr("aa"), slot("01")); ok { + t.Fatalf("slot present, expected missing") + } + verifyAddrs("aa", "bb") + verifySlots("bb", "01", "02", "03") + + state.journal.revert(state, 4) + if _, ok := state.SlotInAccessList(addr("bb"), slot("03")); ok { + t.Fatalf("slot present, expected missing") + } + verifyAddrs("aa", "bb") + verifySlots("bb", "01", "02") + + state.journal.revert(state, 3) + if _, ok := state.SlotInAccessList(addr("bb"), slot("02")); ok { + t.Fatalf("slot present, expected missing") + } + verifyAddrs("aa", "bb") + verifySlots("bb", "01") + + state.journal.revert(state, 2) + if _, ok := state.SlotInAccessList(addr("bb"), slot("01")); ok { + t.Fatalf("slot present, expected missing") + } + verifyAddrs("aa", "bb") + + state.journal.revert(state, 1) + if state.AddressInAccessList(addr("bb")) { + t.Fatalf("addr present, expected missing") + } + verifyAddrs("aa") + + state.journal.revert(state, 0) + if state.AddressInAccessList(addr("aa")) { + t.Fatalf("addr present, expected missing") + } + if got, exp := len(state.accessList.addresses), 0; got != exp { + t.Fatalf("expected empty, got %d", got) + } + if got, exp := len(state.accessList.slots), 0; got != exp { + t.Fatalf("expected empty, got %d", got) + } + // Check the copy + // Make a copy + state = stateCopy1 + verifyAddrs("aa", "bb") + verifySlots("bb", "01", "02") + if got, exp := len(state.accessList.addresses), 2; got != exp { + t.Fatalf("expected empty, got %d", got) + } + if got, exp := len(state.accessList.slots), 1; got 
!= exp { + t.Fatalf("expected empty, got %d", got) + } +} + +func TestStateDBTransientStorage(t *testing.T) { + memDb := rawdb.NewMemoryDatabase() + db := NewDatabase(memDb) + state, _ := New(common.Hash{}, db, nil) + + key := common.Hash{0x01} + value := common.Hash{0x02} + addr := common.Address{} + + state.SetTransientState(addr, key, value) + if exp, got := 1, state.journal.length(); exp != got { + t.Fatalf("journal length mismatch: have %d, want %d", got, exp) + } + // the retrieved value should equal what was set + if got := state.GetTransientState(addr, key); got != value { + t.Fatalf("transient storage mismatch: have %x, want %x", got, value) + } + + // revert the transient state being set and then check that the + // value is now the empty hash + state.journal.revert(state, 0) + if got, exp := state.GetTransientState(addr, key), (common.Hash{}); exp != got { + t.Fatalf("transient storage mismatch: have %x, want %x", got, exp) + } + + // set transient state and then copy the statedb and ensure that + // the transient state is copied + state.SetTransientState(addr, key, value) + cpy := state.Copy() + if got := cpy.GetTransientState(addr, key); got != value { + t.Fatalf("transient storage mismatch: have %x, want %x", got, value) + } +} diff --git a/trie_by_cid/state/transient_storage.go b/trie_by_cid/state/transient_storage.go new file mode 100644 index 0000000..66e563e --- /dev/null +++ b/trie_by_cid/state/transient_storage.go @@ -0,0 +1,55 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. 
+// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package state + +import ( + "github.com/ethereum/go-ethereum/common" +) + +// transientStorage is a representation of EIP-1153 "Transient Storage". +type transientStorage map[common.Address]Storage + +// newTransientStorage creates a new instance of a transientStorage. +func newTransientStorage() transientStorage { + return make(transientStorage) +} + +// Set sets the transient-storage `value` for `key` at the given `addr`. +func (t transientStorage) Set(addr common.Address, key, value common.Hash) { + if _, ok := t[addr]; !ok { + t[addr] = make(Storage) + } + t[addr][key] = value +} + +// Get gets the transient storage for `key` at the given `addr`. +func (t transientStorage) Get(addr common.Address, key common.Hash) common.Hash { + val, ok := t[addr] + if !ok { + return common.Hash{} + } + return val[key] +} + +// Copy does a deep copy of the transientStorage +func (t transientStorage) Copy() transientStorage { + storage := make(transientStorage) + for key, value := range t { + storage[key] = value.Copy() + } + return storage +} diff --git a/trie_by_cid/state/trie_prefetcher.go b/trie_by_cid/state/trie_prefetcher.go new file mode 100644 index 0000000..5dd1b5b --- /dev/null +++ b/trie_by_cid/state/trie_prefetcher.go @@ -0,0 +1,354 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. 
+// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package state + +import ( + "sync" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/metrics" + log "github.com/sirupsen/logrus" +) + +var ( + // triePrefetchMetricsPrefix is the prefix under which to publish the metrics. + triePrefetchMetricsPrefix = "trie/prefetch/" +) + +// triePrefetcher is an active prefetcher, which receives accounts or storage +// items and does trie-loading of them. The goal is to get as much useful content +// into the caches as possible. +// +// Note, the prefetcher's API is not thread safe. 
+type triePrefetcher struct {
+	db       Database               // Database to fetch trie nodes through
+	root     common.Hash            // Root hash of the account trie for metrics
+	fetches  map[string]Trie        // Partially or fully fetched tries
+	fetchers map[string]*subfetcher // Subfetchers for each trie
+
+	deliveryMissMeter metrics.Meter
+	accountLoadMeter  metrics.Meter
+	accountDupMeter   metrics.Meter
+	accountSkipMeter  metrics.Meter
+	accountWasteMeter metrics.Meter
+	storageLoadMeter  metrics.Meter
+	storageDupMeter   metrics.Meter
+	storageSkipMeter  metrics.Meter
+	storageWasteMeter metrics.Meter
+}
+
+func newTriePrefetcher(db Database, root common.Hash, namespace string) *triePrefetcher {
+	prefix := triePrefetchMetricsPrefix + namespace
+	p := &triePrefetcher{
+		db:       db,
+		root:     root,
+		fetchers: make(map[string]*subfetcher), // Active prefetchers use the fetchers map
+
+		deliveryMissMeter: metrics.GetOrRegisterMeter(prefix+"/deliverymiss", nil),
+		accountLoadMeter:  metrics.GetOrRegisterMeter(prefix+"/account/load", nil),
+		accountDupMeter:   metrics.GetOrRegisterMeter(prefix+"/account/dup", nil),
+		accountSkipMeter:  metrics.GetOrRegisterMeter(prefix+"/account/skip", nil),
+		accountWasteMeter: metrics.GetOrRegisterMeter(prefix+"/account/waste", nil),
+		storageLoadMeter:  metrics.GetOrRegisterMeter(prefix+"/storage/load", nil),
+		storageDupMeter:   metrics.GetOrRegisterMeter(prefix+"/storage/dup", nil),
+		storageSkipMeter:  metrics.GetOrRegisterMeter(prefix+"/storage/skip", nil),
+		storageWasteMeter: metrics.GetOrRegisterMeter(prefix+"/storage/waste", nil),
+	}
+	return p
+}
+
+// close iterates over all the subfetchers, aborts any that were left spinning
+// and reports the stats to the metrics subsystem.
+func (p *triePrefetcher) close() {
+	for _, fetcher := range p.fetchers {
+		fetcher.abort() // safe to do multiple times
+
+		if metrics.Enabled {
+			if fetcher.root == p.root {
+				p.accountLoadMeter.Mark(int64(len(fetcher.seen)))
+				p.accountDupMeter.Mark(int64(fetcher.dups))
+				p.accountSkipMeter.Mark(int64(len(fetcher.tasks)))
+
+				for _, key := range fetcher.used {
+					delete(fetcher.seen, string(key))
+				}
+				p.accountWasteMeter.Mark(int64(len(fetcher.seen)))
+			} else {
+				p.storageLoadMeter.Mark(int64(len(fetcher.seen)))
+				p.storageDupMeter.Mark(int64(fetcher.dups))
+				p.storageSkipMeter.Mark(int64(len(fetcher.tasks)))
+
+				for _, key := range fetcher.used {
+					delete(fetcher.seen, string(key))
+				}
+				p.storageWasteMeter.Mark(int64(len(fetcher.seen)))
+			}
+		}
+	}
+	// Clear out all fetchers (will crash on a second call, deliberate)
+	p.fetchers = nil
+}
+
+// copy creates a deep-but-inactive copy of the trie prefetcher. Any trie data
+// already loaded will be copied over, but no goroutines will be started. This
+// is mostly used in the miner which creates a copy of its actively mutated
+// state to be sealed while it may further mutate the state.
+func (p *triePrefetcher) copy() *triePrefetcher { + copy := &triePrefetcher{ + db: p.db, + root: p.root, + fetches: make(map[string]Trie), // Active prefetchers use the fetches map + + deliveryMissMeter: p.deliveryMissMeter, + accountLoadMeter: p.accountLoadMeter, + accountDupMeter: p.accountDupMeter, + accountSkipMeter: p.accountSkipMeter, + accountWasteMeter: p.accountWasteMeter, + storageLoadMeter: p.storageLoadMeter, + storageDupMeter: p.storageDupMeter, + storageSkipMeter: p.storageSkipMeter, + storageWasteMeter: p.storageWasteMeter, + } + // If the prefetcher is already a copy, duplicate the data + if p.fetches != nil { + for root, fetch := range p.fetches { + if fetch == nil { + continue + } + copy.fetches[root] = p.db.CopyTrie(fetch) + } + return copy + } + // Otherwise we're copying an active fetcher, retrieve the current states + for id, fetcher := range p.fetchers { + copy.fetches[id] = fetcher.peek() + } + return copy +} + +// prefetch schedules a batch of trie items to prefetch. +func (p *triePrefetcher) prefetch(owner common.Hash, root common.Hash, keys [][]byte) { + // If the prefetcher is an inactive one, bail out + if p.fetches != nil { + return + } + // Active fetcher, schedule the retrievals + id := p.trieID(owner, root) + fetcher := p.fetchers[id] + if fetcher == nil { + fetcher = newSubfetcher(p.db, p.root, owner, root) + p.fetchers[id] = fetcher + } + fetcher.schedule(keys) +} + +// trie returns the trie matching the root hash, or nil if the prefetcher doesn't +// have it. 
+func (p *triePrefetcher) trie(owner common.Hash, root common.Hash) Trie { + // If the prefetcher is inactive, return from existing deep copies + id := p.trieID(owner, root) + if p.fetches != nil { + trie := p.fetches[id] + if trie == nil { + p.deliveryMissMeter.Mark(1) + return nil + } + return p.db.CopyTrie(trie) + } + // Otherwise the prefetcher is active, bail if no trie was prefetched for this root + fetcher := p.fetchers[id] + if fetcher == nil { + p.deliveryMissMeter.Mark(1) + return nil + } + // Interrupt the prefetcher if it's by any chance still running and return + // a copy of any pre-loaded trie. + fetcher.abort() // safe to do multiple times + + trie := fetcher.peek() + if trie == nil { + p.deliveryMissMeter.Mark(1) + return nil + } + return trie +} + +// used marks a batch of state items used to allow creating statistics as to +// how useful or wasteful the prefetcher is. +func (p *triePrefetcher) used(owner common.Hash, root common.Hash, used [][]byte) { + if fetcher := p.fetchers[p.trieID(owner, root)]; fetcher != nil { + fetcher.used = used + } +} + +// trieID returns an unique trie identifier consists the trie owner and root hash. +func (p *triePrefetcher) trieID(owner common.Hash, root common.Hash) string { + return string(append(owner.Bytes(), root.Bytes()...)) +} + +// subfetcher is a trie fetcher goroutine responsible for pulling entries for a +// single trie. It is spawned when a new root is encountered and lives until the +// main prefetcher is paused and either all requested items are processed or if +// the trie being worked on is retrieved from the prefetcher. 
+type subfetcher struct { + db Database // Database to load trie nodes through + state common.Hash // Root hash of the state to prefetch + owner common.Hash // Owner of the trie, usually account hash + root common.Hash // Root hash of the trie to prefetch + trie Trie // Trie being populated with nodes + + tasks [][]byte // Items queued up for retrieval + lock sync.Mutex // Lock protecting the task queue + + wake chan struct{} // Wake channel if a new task is scheduled + stop chan struct{} // Channel to interrupt processing + term chan struct{} // Channel to signal interruption + copy chan chan Trie // Channel to request a copy of the current trie + + seen map[string]struct{} // Tracks the entries already loaded + dups int // Number of duplicate preload tasks + used [][]byte // Tracks the entries used in the end +} + +// newSubfetcher creates a goroutine to prefetch state items belonging to a +// particular root hash. +func newSubfetcher(db Database, state common.Hash, owner common.Hash, root common.Hash) *subfetcher { + sf := &subfetcher{ + db: db, + state: state, + owner: owner, + root: root, + wake: make(chan struct{}, 1), + stop: make(chan struct{}), + term: make(chan struct{}), + copy: make(chan chan Trie), + seen: make(map[string]struct{}), + } + go sf.loop() + return sf +} + +// schedule adds a batch of trie keys to the queue to prefetch. +func (sf *subfetcher) schedule(keys [][]byte) { + // Append the tasks to the current queue + sf.lock.Lock() + sf.tasks = append(sf.tasks, keys...) + sf.lock.Unlock() + + // Notify the prefetcher, it's fine if it's already terminated + select { + case sf.wake <- struct{}{}: + default: + } +} + +// peek tries to retrieve a deep copy of the fetcher's trie in whatever form it +// is currently. 
+func (sf *subfetcher) peek() Trie { + ch := make(chan Trie) + select { + case sf.copy <- ch: + // Subfetcher still alive, return copy from it + return <-ch + + case <-sf.term: + // Subfetcher already terminated, return a copy directly + if sf.trie == nil { + return nil + } + return sf.db.CopyTrie(sf.trie) + } +} + +// abort interrupts the subfetcher immediately. It is safe to call abort multiple +// times but it is not thread safe. +func (sf *subfetcher) abort() { + select { + case <-sf.stop: + default: + close(sf.stop) + } + <-sf.term +} + +// loop waits for new tasks to be scheduled and keeps loading them until it runs +// out of tasks or its underlying trie is retrieved for committing. +func (sf *subfetcher) loop() { + // No matter how the loop stops, signal anyone waiting that it's terminated + defer close(sf.term) + + // Start by opening the trie and stop processing if it fails + if sf.owner == (common.Hash{}) { + trie, err := sf.db.OpenTrie(sf.root) + if err != nil { + log.Warn("Trie prefetcher failed opening trie", "root", sf.root, "err", err) + return + } + sf.trie = trie + } else { + trie, err := sf.db.OpenStorageTrie(sf.state, sf.owner, sf.root) + if err != nil { + log.Warn("Trie prefetcher failed opening trie", "root", sf.root, "err", err) + return + } + sf.trie = trie + } + // Trie opened successfully, keep prefetching items + for { + select { + case <-sf.wake: + // Subfetcher was woken up, retrieve any tasks to avoid spinning the lock + sf.lock.Lock() + tasks := sf.tasks + sf.tasks = nil + sf.lock.Unlock() + + // Prefetch any tasks until the loop is interrupted + for i, task := range tasks { + select { + case <-sf.stop: + // If termination is requested, add any leftover back and return + sf.lock.Lock() + sf.tasks = append(sf.tasks, tasks[i:]...) 
+ sf.lock.Unlock() + return + + case ch := <-sf.copy: + // Somebody wants a copy of the current trie, grant them + ch <- sf.db.CopyTrie(sf.trie) + + default: + // No termination request yet, prefetch the next entry + if _, ok := sf.seen[string(task)]; ok { + sf.dups++ + } else { + sf.trie.TryGet(task) + sf.seen[string(task)] = struct{}{} + } + } + } + + case ch := <-sf.copy: + // Somebody wants a copy of the current trie, grant them + ch <- sf.db.CopyTrie(sf.trie) + + case <-sf.stop: + // Termination is requested, abort and leave remaining tasks + return + } + } +} diff --git a/trie_by_cid/state/trie_prefetcher_test.go b/trie_by_cid/state/trie_prefetcher_test.go new file mode 100644 index 0000000..cb0b67d --- /dev/null +++ b/trie_by_cid/state/trie_prefetcher_test.go @@ -0,0 +1,110 @@ +// Copyright 2021 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . 
+ +package state + +import ( + "math/big" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" +) + +func filledStateDB() *StateDB { + state, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil) + + // Create an account and check if the retrieved balance is correct + addr := common.HexToAddress("0xaffeaffeaffeaffeaffeaffeaffeaffeaffeaffe") + skey := common.HexToHash("aaa") + sval := common.HexToHash("bbb") + + state.SetBalance(addr, big.NewInt(42)) // Change the account trie + state.SetCode(addr, []byte("hello")) // Change an external metadata + state.SetState(addr, skey, sval) // Change the storage trie + for i := 0; i < 100; i++ { + sk := common.BigToHash(big.NewInt(int64(i))) + state.SetState(addr, sk, sk) // Change the storage trie + } + return state +} + +func TestCopyAndClose(t *testing.T) { + db := filledStateDB() + prefetcher := newTriePrefetcher(db.db, db.originalRoot, "") + skey := common.HexToHash("aaa") + prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()}) + prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()}) + time.Sleep(1 * time.Second) + a := prefetcher.trie(common.Hash{}, db.originalRoot) + prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()}) + b := prefetcher.trie(common.Hash{}, db.originalRoot) + cpy := prefetcher.copy() + cpy.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()}) + cpy.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()}) + c := cpy.trie(common.Hash{}, db.originalRoot) + prefetcher.close() + cpy2 := cpy.copy() + cpy2.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()}) + d := cpy2.trie(common.Hash{}, db.originalRoot) + cpy.close() + cpy2.close() + if a.Hash() != b.Hash() || a.Hash() != c.Hash() || a.Hash() != d.Hash() { + t.Fatalf("Invalid trie, hashes should be equal: %v %v %v %v", a.Hash(), b.Hash(), c.Hash(), d.Hash()) + } +} + +func 
TestUseAfterClose(t *testing.T) { + db := filledStateDB() + prefetcher := newTriePrefetcher(db.db, db.originalRoot, "") + skey := common.HexToHash("aaa") + prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()}) + a := prefetcher.trie(common.Hash{}, db.originalRoot) + prefetcher.close() + b := prefetcher.trie(common.Hash{}, db.originalRoot) + if a == nil { + t.Fatal("Prefetching before close should not return nil") + } + if b != nil { + t.Fatal("Trie after close should return nil") + } +} + +func TestCopyClose(t *testing.T) { + db := filledStateDB() + prefetcher := newTriePrefetcher(db.db, db.originalRoot, "") + skey := common.HexToHash("aaa") + prefetcher.prefetch(common.Hash{}, db.originalRoot, [][]byte{skey.Bytes()}) + cpy := prefetcher.copy() + a := prefetcher.trie(common.Hash{}, db.originalRoot) + b := cpy.trie(common.Hash{}, db.originalRoot) + prefetcher.close() + c := prefetcher.trie(common.Hash{}, db.originalRoot) + d := cpy.trie(common.Hash{}, db.originalRoot) + if a == nil { + t.Fatal("Prefetching before close should not return nil") + } + if b == nil { + t.Fatal("Copy trie should return nil") + } + if c != nil { + t.Fatal("Trie after close should return nil") + } + if d == nil { + t.Fatal("Copy trie should not return nil") + } +} diff --git a/trie_by_cid/trie/committer.go b/trie_by_cid/trie/committer.go new file mode 100644 index 0000000..9f97887 --- /dev/null +++ b/trie_by_cid/trie/committer.go @@ -0,0 +1,196 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. 
+// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package trie + +import ( + "fmt" + + "github.com/ethereum/go-ethereum/common" +) + +// leaf represents a trie leaf node +type leaf struct { + blob []byte // raw blob of leaf + parent common.Hash // the hash of parent node +} + +// committer is the tool used for the trie Commit operation. The committer will +// capture all dirty nodes during the commit process and keep them cached in +// insertion order. +type committer struct { + nodes *NodeSet + collectLeaf bool +} + +// newCommitter creates a new committer or picks one from the pool. +func newCommitter(nodeset *NodeSet, collectLeaf bool) *committer { + return &committer{ + nodes: nodeset, + collectLeaf: collectLeaf, + } +} + +// Commit collapses a node down into a hash node. +func (c *committer) Commit(n node) hashNode { + return c.commit(nil, n).(hashNode) +} + +// commit collapses a node down into a hash node and returns it. +func (c *committer) commit(path []byte, n node) node { + // if this path is clean, use available cached data + hash, dirty := n.cache() + if hash != nil && !dirty { + return hash + } + // Commit children, then parent, and remove the dirty flag. + switch cn := n.(type) { + case *shortNode: + // Commit child + collapsed := cn.copy() + + // If the child is fullNode, recursively commit, + // otherwise it can only be hashNode or valueNode. + if _, ok := cn.Val.(*fullNode); ok { + collapsed.Val = c.commit(append(path, cn.Key...), cn.Val) + } + // The key needs to be copied, since we're adding it to the + // modified nodeset. 
+ collapsed.Key = hexToCompact(cn.Key) + hashedNode := c.store(path, collapsed) + if hn, ok := hashedNode.(hashNode); ok { + return hn + } + return collapsed + case *fullNode: + hashedKids := c.commitChildren(path, cn) + collapsed := cn.copy() + collapsed.Children = hashedKids + + hashedNode := c.store(path, collapsed) + if hn, ok := hashedNode.(hashNode); ok { + return hn + } + return collapsed + case hashNode: + return cn + default: + // nil, valuenode shouldn't be committed + panic(fmt.Sprintf("%T: invalid node: %v", n, n)) + } +} + +// commitChildren commits the children of the given fullnode +func (c *committer) commitChildren(path []byte, n *fullNode) [17]node { + var children [17]node + for i := 0; i < 16; i++ { + child := n.Children[i] + if child == nil { + continue + } + // If it's the hashed child, save the hash value directly. + // Note: it's impossible that the child in range [0, 15] + // is a valueNode. + if hn, ok := child.(hashNode); ok { + children[i] = hn + continue + } + // Commit the child recursively and store the "hashed" value. + // Note the returned node can be some embedded nodes, so it's + // possible the type is not hashNode. + children[i] = c.commit(append(path, byte(i)), child) + } + // For the 17th child, it's possible the type is valuenode. + if n.Children[16] != nil { + children[16] = n.Children[16] + } + return children +} + +// store hashes the node n and adds it to the modified nodeset. If leaf collection +// is enabled, leaf nodes will be tracked in the modified nodeset as well. +func (c *committer) store(path []byte, n node) node { + // Larger nodes are replaced by their hash and stored in the database. + var hash, _ = n.cache() + + // This was not generated - must be a small node stored in the parent. + // In theory, we should check if the node is leaf here (embedded node + // usually is leaf node). But small value (less than 32bytes) is not + // our target (leaves in account trie only). 
+ if hash == nil { + // The node is embedded in its parent, in other words, this node + // will not be stored in the database independently, mark it as + // deleted only if the node was existent in database before. + if _, ok := c.nodes.accessList[string(path)]; ok { + c.nodes.markDeleted(path) + } + return n + } + // We have the hash already, estimate the RLP encoding-size of the node. + // The size is used for mem tracking, does not need to be exact + var ( + size = estimateSize(n) + nhash = common.BytesToHash(hash) + mnode = &memoryNode{ + hash: nhash, + node: simplifyNode(n), + size: uint16(size), + } + ) + // Collect the dirty node to nodeset for return. + c.nodes.markUpdated(path, mnode) + + // Collect the corresponding leaf node if it's required. We don't check + // full node since it's impossible to store value in fullNode. The key + // length of leaves should be exactly same. + if c.collectLeaf { + if sn, ok := n.(*shortNode); ok { + if val, ok := sn.Val.(valueNode); ok { + c.nodes.addLeaf(&leaf{blob: val, parent: nhash}) + } + } + } + return hash +} + +// estimateSize estimates the size of an rlp-encoded node, without actually +// rlp-encoding it (zero allocs). This method has been experimentally tried, and with a trie +// with 1000 leaves, the only errors above 1% are on small shortnodes, where this +// method overestimates by 2 or 3 bytes (e.g. 37 instead of 35) +func estimateSize(n node) int { + switch n := n.(type) { + case *shortNode: + // A short node contains a compacted key, and a value. 
+ return 3 + len(n.Key) + estimateSize(n.Val) + case *fullNode: + // A full node contains up to 16 hashes (some nils), and a key + s := 3 + for i := 0; i < 16; i++ { + if child := n.Children[i]; child != nil { + s += estimateSize(child) + } else { + s++ + } + } + return s + case valueNode: + return 1 + len(n) + case hashNode: + return 1 + len(n) + default: + panic(fmt.Sprintf("node type %T", n)) + } +} diff --git a/trie_by_cid/trie/database.go b/trie_by_cid/trie/database.go index d13cb49..ac05f98 100644 --- a/trie_by_cid/trie/database.go +++ b/trie_by_cid/trie/database.go @@ -18,25 +18,50 @@ package trie import ( "errors" + "runtime" + "sync" + "time" "github.com/VictoriaMetrics/fastcache" + "github.com/cerc-io/ipld-eth-statedb/internal" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" - "github.com/ipfs/go-cid" + log "github.com/sirupsen/logrus" ) -type CidKey = cid.Cid - -func isEmpty(key CidKey) bool { - return len(key.KeyString()) == 0 -} - -// Database is an intermediate read-only layer between the trie data structures and -// the disk database. This trie Database is thread safe in providing individual, -// independent node access. +// Database is an intermediate write layer between the trie data structures and +// the disk database. The aim is to accumulate trie writes in-memory and only +// periodically flush a couple tries to disk, garbage collecting the remainder. +// +// Note, the trie Database is **not** thread safe in its mutations, but it **is** +// thread safe in providing individual, independent node access. The rationale +// behind this split design is to provide read access to RPC handlers and sync +// servers even while the trie is executing expensive garbage collection. 
type Database struct { - diskdb ethdb.KeyValueStore // Persistent storage for matured trie nodes - cleans *fastcache.Cache // GC friendly memory cache of clean node RLPs + diskdb ethdb.Database // Persistent storage for matured trie nodes + + cleans *fastcache.Cache // GC friendly memory cache of clean node RLPs + dirties map[common.Hash]*cachedNode // Data and references relationships of dirty trie nodes + oldest common.Hash // Oldest tracked node, flush-list head + newest common.Hash // Newest tracked node, flush-list tail + + gctime time.Duration // Time spent on garbage collection since last commit + gcnodes uint64 // Nodes garbage collected since last commit + gcsize common.StorageSize // Data storage garbage collected since last commit + + flushtime time.Duration // Time spent on data flushing since last commit + flushnodes uint64 // Nodes flushed since last commit + flushsize common.StorageSize // Data storage flushed since last commit + + dirtiesSize common.StorageSize // Storage size of the dirty node cache (exc. metadata) + childrenSize common.StorageSize // Storage size of the external children tracking + preimages *preimageStore // The store for caching preimages + + lock sync.RWMutex } // Config defines all necessary options for database. @@ -46,14 +71,14 @@ type Config = trie.Config // NewDatabase creates a new trie database to store ephemeral trie content before // its written out to disk or garbage collected. No read cache is created, so all // data retrievals will hit the underlying disk database. -func NewDatabase(diskdb ethdb.KeyValueStore) *Database { +func NewDatabase(diskdb ethdb.Database) *Database { return NewDatabaseWithConfig(diskdb, nil) } // NewDatabaseWithConfig creates a new trie database to store ephemeral trie content -// before it's written out to disk or garbage collected. It also acts as a read cache +// before its written out to disk or garbage collected. It also acts as a read cache // for nodes loaded from disk. 
-func NewDatabaseWithConfig(diskdb ethdb.KeyValueStore, config *Config) *Database { +func NewDatabaseWithConfig(diskdb ethdb.Database, config *Config) *Database { var cleans *fastcache.Cache if config != nil && config.Cache > 0 { if config.Journal == "" { @@ -62,44 +87,354 @@ func NewDatabaseWithConfig(diskdb ethdb.KeyValueStore, config *Config) *Database cleans = fastcache.LoadFromFileOrNew(config.Journal, config.Cache*1024*1024) } } + var preimage *preimageStore + if config != nil && config.Preimages { + preimage = newPreimageStore(diskdb) + } db := &Database{ diskdb: diskdb, cleans: cleans, + dirties: map[common.Hash]*cachedNode{{}: { + children: make(map[common.Hash]uint16), + }}, + preimages: preimage, } return db } -// DiskDB retrieves the persistent storage backing the trie database. -func (db *Database) DiskDB() ethdb.KeyValueStore { - return db.diskdb +// insert inserts a simplified trie node into the memory database. +// All nodes inserted by this function will be reference tracked +// and in theory should only used for **trie nodes** insertion. +func (db *Database) insert(hash common.Hash, size int, node node) { + // If the node's already cached, skip + if _, ok := db.dirties[hash]; ok { + return + } + memcacheDirtyWriteMeter.Mark(int64(size)) + + // Create the cached entry for this node + entry := &cachedNode{ + node: node, + size: uint16(size), + flushPrev: db.newest, + } + entry.forChilds(func(child common.Hash) { + if c := db.dirties[child]; c != nil { + c.parents++ + } + }) + db.dirties[hash] = entry + + // Update the flush-list endpoints + if db.oldest == (common.Hash{}) { + db.oldest, db.newest = hash, hash + } else { + db.dirties[db.newest].flushNext, db.newest = hash, hash + } + db.dirtiesSize += common.StorageSize(common.HashLength + entry.size) } -// Node retrieves an encoded trie node by CID. If it cannot be found -// cached in memory, it queries the persistent database. 
-func (db *Database) Node(key CidKey) ([]byte, error) { +// Node retrieves an encoded cached trie node from memory. If it cannot be found +// cached, the method queries the persistent database for the content. +func (db *Database) Node(hash common.Hash, codec uint64) ([]byte, error) { // It doesn't make sense to retrieve the metaroot - if isEmpty(key) { + if hash == (common.Hash{}) { return nil, errors.New("not found") } - cidbytes := key.Bytes() // Retrieve the node from the clean cache if available if db.cleans != nil { - if enc := db.cleans.Get(nil, cidbytes); enc != nil { + if enc := db.cleans.Get(nil, hash[:]); enc != nil { + memcacheCleanHitMeter.Mark(1) + memcacheCleanReadMeter.Mark(int64(len(enc))) return enc, nil } } + // Retrieve the node from the dirty cache if available + db.lock.RLock() + dirty := db.dirties[hash] + db.lock.RUnlock() + + if dirty != nil { + memcacheDirtyHitMeter.Mark(1) + memcacheDirtyReadMeter.Mark(int64(dirty.size)) + return dirty.rlp(), nil + } + memcacheDirtyMissMeter.Mark(1) // Content unavailable in memory, attempt to retrieve from disk - enc, err := db.diskdb.Get(cidbytes) + cid, err := internal.Keccak256ToCid(codec, hash[:]) + if err != nil { + return nil, err + } + enc, err := db.diskdb.Get(cid.Bytes()) if err != nil { return nil, err } - if len(enc) != 0 { if db.cleans != nil { - db.cleans.Set(cidbytes, enc) + db.cleans.Set(hash[:], enc) + memcacheCleanMissMeter.Mark(1) + memcacheCleanWriteMeter.Mark(int64(len(enc))) } return enc, nil } return nil, errors.New("not found") } + +// Nodes retrieves the hashes of all the nodes cached within the memory database. +// This method is extremely expensive and should only be used to validate internal +// states in test code. 
+func (db *Database) Nodes() []common.Hash { + db.lock.RLock() + defer db.lock.RUnlock() + + var hashes = make([]common.Hash, 0, len(db.dirties)) + for hash := range db.dirties { + if hash != (common.Hash{}) { // Special case for "root" references/nodes + hashes = append(hashes, hash) + } + } + return hashes +} + +// Reference adds a new reference from a parent node to a child node. +// This function is used to add reference between internal trie node +// and external node(e.g. storage trie root), all internal trie nodes +// are referenced together by database itself. +func (db *Database) Reference(child common.Hash, parent common.Hash) { + db.lock.Lock() + defer db.lock.Unlock() + + db.reference(child, parent) +} + +// reference is the private locked version of Reference. +func (db *Database) reference(child common.Hash, parent common.Hash) { + // If the node does not exist, it's a node pulled from disk, skip + node, ok := db.dirties[child] + if !ok { + return + } + // If the reference already exists, only duplicate for roots + if db.dirties[parent].children == nil { + db.dirties[parent].children = make(map[common.Hash]uint16) + db.childrenSize += cachedNodeChildrenSize + } else if _, ok = db.dirties[parent].children[child]; ok && parent != (common.Hash{}) { + return + } + node.parents++ + db.dirties[parent].children[child]++ + if db.dirties[parent].children[child] == 1 { + db.childrenSize += common.HashLength + 2 // uint16 counter + } +} + +// Dereference removes an existing reference from a root node. 
+func (db *Database) Dereference(root common.Hash) { + // Sanity check to ensure that the meta-root is not removed + if root == (common.Hash{}) { + log.Error("Attempted to dereference the trie cache meta root") + return + } + db.lock.Lock() + defer db.lock.Unlock() + + nodes, storage, start := len(db.dirties), db.dirtiesSize, time.Now() + db.dereference(root, common.Hash{}) + + db.gcnodes += uint64(nodes - len(db.dirties)) + db.gcsize += storage - db.dirtiesSize + db.gctime += time.Since(start) + + memcacheGCTimeTimer.Update(time.Since(start)) + memcacheGCSizeMeter.Mark(int64(storage - db.dirtiesSize)) + memcacheGCNodesMeter.Mark(int64(nodes - len(db.dirties))) + + log.Debug("Dereferenced trie from memory database", "nodes", nodes-len(db.dirties), "size", storage-db.dirtiesSize, "time", time.Since(start), + "gcnodes", db.gcnodes, "gcsize", db.gcsize, "gctime", db.gctime, "livenodes", len(db.dirties), "livesize", db.dirtiesSize) +} + +// dereference is the private locked version of Dereference. +func (db *Database) dereference(child common.Hash, parent common.Hash) { + // Dereference the parent-child + node := db.dirties[parent] + + if node.children != nil && node.children[child] > 0 { + node.children[child]-- + if node.children[child] == 0 { + delete(node.children, child) + db.childrenSize -= (common.HashLength + 2) // uint16 counter + } + } + // If the child does not exist, it's a previously committed node. + node, ok := db.dirties[child] + if !ok { + return + } + // If there are no more references to the child, delete it and cascade + if node.parents > 0 { + // This is a special cornercase where a node loaded from disk (i.e. not in the + // memcache any more) gets reinjected as a new node (short node split into full, + // then reverted into short), causing a cached node to have no parents. That is + // no problem in itself, but don't make maxint parents out of it. 
+ node.parents-- + } + if node.parents == 0 { + // Remove the node from the flush-list + switch child { + case db.oldest: + db.oldest = node.flushNext + db.dirties[node.flushNext].flushPrev = common.Hash{} + case db.newest: + db.newest = node.flushPrev + db.dirties[node.flushPrev].flushNext = common.Hash{} + default: + db.dirties[node.flushPrev].flushNext = node.flushNext + db.dirties[node.flushNext].flushPrev = node.flushPrev + } + // Dereference all children and delete the node + node.forChilds(func(hash common.Hash) { + db.dereference(hash, child) + }) + delete(db.dirties, child) + db.dirtiesSize -= common.StorageSize(common.HashLength + int(node.size)) + if node.children != nil { + db.childrenSize -= cachedNodeChildrenSize + } + } +} + +// Update inserts the dirty nodes in provided nodeset into database and +// link the account trie with multiple storage tries if necessary. +func (db *Database) Update(nodes *MergedNodeSet) error { + db.lock.Lock() + defer db.lock.Unlock() + + // Insert dirty nodes into the database. In the same tree, it must be + // ensured that children are inserted first, then parent so that children + // can be linked with their parent correctly. + // + // Note, the storage tries must be flushed before the account trie to + // retain the invariant that children go into the dirty cache first. + var order []common.Hash + for owner := range nodes.sets { + if owner == (common.Hash{}) { + continue + } + order = append(order, owner) + } + if _, ok := nodes.sets[common.Hash{}]; ok { + order = append(order, common.Hash{}) + } + for _, owner := range order { + subset := nodes.sets[owner] + subset.forEachWithOrder(func(path string, n *memoryNode) { + if n.isDeleted() { + return // ignore deletion + } + db.insert(n.hash, int(n.size), n.node) + }) + } + // Link up the account trie and storage trie if the node points + // to an account trie leaf. 
+ if set, present := nodes.sets[common.Hash{}]; present { + for _, n := range set.leaves { + var account types.StateAccount + if err := rlp.DecodeBytes(n.blob, &account); err != nil { + return err + } + if account.Root != types.EmptyRootHash { + db.reference(account.Root, n.parent) + } + } + } + return nil +} + +// Size returns the current storage size of the memory cache in front of the +// persistent database layer. +func (db *Database) Size() (common.StorageSize, common.StorageSize) { + db.lock.RLock() + defer db.lock.RUnlock() + + // db.dirtiesSize only contains the useful data in the cache, but when reporting + // the total memory consumption, the maintenance metadata is also needed to be + // counted. + var metadataSize = common.StorageSize((len(db.dirties) - 1) * cachedNodeSize) + var metarootRefs = common.StorageSize(len(db.dirties[common.Hash{}].children) * (common.HashLength + 2)) + var preimageSize common.StorageSize + if db.preimages != nil { + preimageSize = db.preimages.size() + } + return db.dirtiesSize + db.childrenSize + metadataSize - metarootRefs, preimageSize +} + +// GetReader retrieves a node reader belonging to the given state root. +func (db *Database) GetReader(root common.Hash, codec uint64) Reader { + return &hashReader{db: db, codec: codec} +} + +// hashReader is reader of hashDatabase which implements the Reader interface. +type hashReader struct { + db *Database + codec uint64 +} + +// Node retrieves the trie node with the given node hash. +func (reader *hashReader) Node(owner common.Hash, path []byte, hash common.Hash) (node, error) { + blob, err := reader.NodeBlob(owner, path, hash) + if err != nil { + return nil, err + } + return decodeNodeUnsafe(hash[:], blob) +} + +// NodeBlob retrieves the RLP-encoded trie node blob with the given node hash. 
+func (reader *hashReader) NodeBlob(_ common.Hash, _ []byte, hash common.Hash) ([]byte, error) { + return reader.db.Node(hash, reader.codec) +} + +// saveCache saves clean state cache to given directory path +// using specified CPU cores. +func (db *Database) saveCache(dir string, threads int) error { + if db.cleans == nil { + return nil + } + log.Info("Writing clean trie cache to disk", "path", dir, "threads", threads) + + start := time.Now() + err := db.cleans.SaveToFileConcurrent(dir, threads) + if err != nil { + log.Error("Failed to persist clean trie cache", "error", err) + return err + } + log.Info("Persisted the clean trie cache", "path", dir, "elapsed", common.PrettyDuration(time.Since(start))) + return nil +} + +// SaveCache atomically saves fast cache data to the given dir using all +// available CPU cores. +func (db *Database) SaveCache(dir string) error { + return db.saveCache(dir, runtime.GOMAXPROCS(0)) +} + +// SaveCachePeriodically atomically saves fast cache data to the given dir with +// the specified interval. All dump operation will only use a single CPU core. +func (db *Database) SaveCachePeriodically(dir string, interval time.Duration, stopCh <-chan struct{}) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + db.saveCache(dir, 1) + case <-stopCh: + return + } + } +} + +// Scheme returns the node scheme used in the database. 
+func (db *Database) Scheme() string { + return rawdb.HashScheme +} diff --git a/trie_by_cid/trie/database_metrics.go b/trie_by_cid/trie/database_metrics.go new file mode 100644 index 0000000..55efc55 --- /dev/null +++ b/trie_by_cid/trie/database_metrics.go @@ -0,0 +1,27 @@ +package trie + +import "github.com/ethereum/go-ethereum/metrics" + +var ( + memcacheCleanHitMeter = metrics.NewRegisteredMeter("trie/memcache/clean/hit", nil) + memcacheCleanMissMeter = metrics.NewRegisteredMeter("trie/memcache/clean/miss", nil) + memcacheCleanReadMeter = metrics.NewRegisteredMeter("trie/memcache/clean/read", nil) + memcacheCleanWriteMeter = metrics.NewRegisteredMeter("trie/memcache/clean/write", nil) + + memcacheDirtyHitMeter = metrics.NewRegisteredMeter("trie/memcache/dirty/hit", nil) + memcacheDirtyMissMeter = metrics.NewRegisteredMeter("trie/memcache/dirty/miss", nil) + memcacheDirtyReadMeter = metrics.NewRegisteredMeter("trie/memcache/dirty/read", nil) + memcacheDirtyWriteMeter = metrics.NewRegisteredMeter("trie/memcache/dirty/write", nil) + + memcacheFlushTimeTimer = metrics.NewRegisteredResettingTimer("trie/memcache/flush/time", nil) + memcacheFlushNodesMeter = metrics.NewRegisteredMeter("trie/memcache/flush/nodes", nil) + memcacheFlushSizeMeter = metrics.NewRegisteredMeter("trie/memcache/flush/size", nil) + + memcacheGCTimeTimer = metrics.NewRegisteredResettingTimer("trie/memcache/gc/time", nil) + memcacheGCNodesMeter = metrics.NewRegisteredMeter("trie/memcache/gc/nodes", nil) + memcacheGCSizeMeter = metrics.NewRegisteredMeter("trie/memcache/gc/size", nil) + + memcacheCommitTimeTimer = metrics.NewRegisteredResettingTimer("trie/memcache/commit/time", nil) + memcacheCommitNodesMeter = metrics.NewRegisteredMeter("trie/memcache/commit/nodes", nil) + memcacheCommitSizeMeter = metrics.NewRegisteredMeter("trie/memcache/commit/size", nil) +) diff --git a/trie_by_cid/trie/database_node.go b/trie_by_cid/trie/database_node.go new file mode 100644 index 0000000..58037ae --- 
/dev/null +++ b/trie_by_cid/trie/database_node.go @@ -0,0 +1,183 @@ +package trie + +import ( + "fmt" + "io" + "reflect" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/rlp" +) + +// rawNode is a simple binary blob used to differentiate between collapsed trie +// nodes and already encoded RLP binary blobs (while at the same time store them +// in the same cache fields). +type rawNode []byte + +func (n rawNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") } +func (n rawNode) fstring(ind string) string { panic("this should never end up in a live trie") } + +func (n rawNode) EncodeRLP(w io.Writer) error { + _, err := w.Write(n) + return err +} + +// rawFullNode represents only the useful data content of a full node, with the +// caches and flags stripped out to minimize its data storage. This type honors +// the same RLP encoding as the original parent. +type rawFullNode [17]node + +func (n rawFullNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") } +func (n rawFullNode) fstring(ind string) string { panic("this should never end up in a live trie") } + +func (n rawFullNode) EncodeRLP(w io.Writer) error { + eb := rlp.NewEncoderBuffer(w) + n.encode(eb) + return eb.Flush() +} + +// rawShortNode represents only the useful data content of a short node, with the +// caches and flags stripped out to minimize its data storage. This type honors +// the same RLP encoding as the original parent. +type rawShortNode struct { + Key []byte + Val node +} + +func (n rawShortNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") } +func (n rawShortNode) fstring(ind string) string { panic("this should never end up in a live trie") } + +// cachedNode is all the information we know about a single cached trie node +// in the memory database write layer. 
+type cachedNode struct { + node node // Cached collapsed trie node, or raw rlp data + size uint16 // Byte size of the useful cached data + + parents uint32 // Number of live nodes referencing this one + children map[common.Hash]uint16 // External children referenced by this node + + flushPrev common.Hash // Previous node in the flush-list + flushNext common.Hash // Next node in the flush-list +} + +// cachedNodeSize is the raw size of a cachedNode data structure without any +// node data included. It's an approximate size, but should be a lot better +// than not counting them. +var cachedNodeSize = int(reflect.TypeOf(cachedNode{}).Size()) + +// cachedNodeChildrenSize is the raw size of an initialized but empty external +// reference map. +const cachedNodeChildrenSize = 48 + +// rlp returns the raw rlp encoded blob of the cached trie node, either directly +// from the cache, or by regenerating it from the collapsed node. +func (n *cachedNode) rlp() []byte { + if node, ok := n.node.(rawNode); ok { + return node + } + return nodeToBytes(n.node) +} + +// obj returns the decoded and expanded trie node, either directly from the cache, +// or by regenerating it from the rlp encoded blob. +func (n *cachedNode) obj(hash common.Hash) node { + if node, ok := n.node.(rawNode); ok { + // The raw-blob format nodes are loaded either from the + // clean cache or the database, they are all in their own + // copy and safe to use unsafe decoder. + return mustDecodeNodeUnsafe(hash[:], node) + } + return expandNode(hash[:], n.node) +} + +// forChilds invokes the callback for all the tracked children of this node, +// both the implicit ones from inside the node as well as the explicit ones +// from outside the node. 
+func (n *cachedNode) forChilds(onChild func(hash common.Hash)) { + for child := range n.children { + onChild(child) + } + if _, ok := n.node.(rawNode); !ok { + forGatherChildren(n.node, onChild) + } +} + +// forGatherChildren traverses the node hierarchy of a collapsed storage node and +// invokes the callback for all the hashnode children. +func forGatherChildren(n node, onChild func(hash common.Hash)) { + switch n := n.(type) { + case *rawShortNode: + forGatherChildren(n.Val, onChild) + case rawFullNode: + for i := 0; i < 16; i++ { + forGatherChildren(n[i], onChild) + } + case hashNode: + onChild(common.BytesToHash(n)) + case valueNode, nil, rawNode: + default: + panic(fmt.Sprintf("unknown node type: %T", n)) + } +} + +// simplifyNode traverses the hierarchy of an expanded memory node and discards +// all the internal caches, returning a node that only contains the raw data. +func simplifyNode(n node) node { + switch n := n.(type) { + case *shortNode: + // Short nodes discard the flags and cascade + return &rawShortNode{Key: n.Key, Val: simplifyNode(n.Val)} + + case *fullNode: + // Full nodes discard the flags and cascade + node := rawFullNode(n.Children) + for i := 0; i < len(node); i++ { + if node[i] != nil { + node[i] = simplifyNode(node[i]) + } + } + return node + + case valueNode, hashNode, rawNode: + return n + + default: + panic(fmt.Sprintf("unknown node type: %T", n)) + } +} + +// expandNode traverses the node hierarchy of a collapsed storage node and converts +// all fields and keys into expanded memory form. 
+func expandNode(hash hashNode, n node) node { + switch n := n.(type) { + case *rawShortNode: + // Short nodes need key and child expansion + return &shortNode{ + Key: compactToHex(n.Key), + Val: expandNode(nil, n.Val), + flags: nodeFlag{ + hash: hash, + }, + } + + case rawFullNode: + // Full nodes need child expansion + node := &fullNode{ + flags: nodeFlag{ + hash: hash, + }, + } + for i := 0; i < len(node.Children); i++ { + if n[i] != nil { + node.Children[i] = expandNode(nil, n[i]) + } + } + return node + + case valueNode, hashNode: + return n + + default: + panic(fmt.Sprintf("unknown node type: %T", n)) + } +} diff --git a/trie_by_cid/trie/database_test.go b/trie_by_cid/trie/database_test.go index a800135..5a6e36d 100644 --- a/trie_by_cid/trie/database_test.go +++ b/trie_by_cid/trie/database_test.go @@ -14,21 +14,20 @@ // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . -package trie_test +package trie import ( "testing" - "github.com/ethereum/go-ethereum/ethdb/memorydb" - - "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" ) // Tests that the trie database returns a missing trie node error if attempting // to retrieve the meta root. func TestDatabaseMetarootFetch(t *testing.T) { - db := trie.NewDatabase(memorydb.New()) - if _, err := db.Node(trie.CidKey{}); err == nil { + db := NewDatabase(rawdb.NewMemoryDatabase()) + if _, err := db.Node(common.Hash{}, StateTrieCodec); err == nil { t.Fatalf("metaroot retrieval succeeded") } } diff --git a/trie_by_cid/trie/encoding.go b/trie_by_cid/trie/encoding.go index 381883a..ace4570 100644 --- a/trie_by_cid/trie/encoding.go +++ b/trie_by_cid/trie/encoding.go @@ -16,8 +16,81 @@ package trie +// Trie keys are dealt with in three distinct encodings: +// +// KEYBYTES encoding contains the actual key and nothing else. 
This encoding is the +// input to most API functions. +// +// HEX encoding contains one byte for each nibble of the key and an optional trailing +// 'terminator' byte of value 0x10 which indicates whether or not the node at the key +// contains a value. Hex key encoding is used for nodes loaded in memory because it's +// convenient to access. +// +// COMPACT encoding is defined by the Ethereum Yellow Paper (it's called "hex prefix +// encoding" there) and contains the bytes of the key and a flag. The high nibble of the +// first byte contains the flag; the lowest bit encoding the oddness of the length and +// the second-lowest encoding whether the node at the key is a value node. The low nibble +// of the first byte is zero in the case of an even number of nibbles and the first nibble +// in the case of an odd number. All remaining nibbles (now an even number) fit properly +// into the remaining bytes. Compact encoding is used for nodes stored on disk. + +// HexToCompact converts a hex path to the compact encoded format +func HexToCompact(hex []byte) []byte { + return hexToCompact(hex) +} + +func hexToCompact(hex []byte) []byte { + terminator := byte(0) + if hasTerm(hex) { + terminator = 1 + hex = hex[:len(hex)-1] + } + buf := make([]byte, len(hex)/2+1) + buf[0] = terminator << 5 // the flag byte + if len(hex)&1 == 1 { + buf[0] |= 1 << 4 // odd flag + buf[0] |= hex[0] // first nibble is contained in the first byte + hex = hex[1:] + } + decodeNibbles(hex, buf[1:]) + return buf +} + +// hexToCompactInPlace places the compact key in input buffer, returning the length +// needed for the representation +func hexToCompactInPlace(hex []byte) int { + var ( + hexLen = len(hex) // length of the hex input + firstByte = byte(0) + ) + // Check if we have a terminator there + if hexLen > 0 && hex[hexLen-1] == 16 { + firstByte = 1 << 5 + hexLen-- // last part was the terminator, ignore that + } + var ( + binLen = hexLen/2 + 1 + ni = 0 // index in hex + bi = 1 // index in bin 
(compact) + ) + if hexLen&1 == 1 { + firstByte |= 1 << 4 // odd flag + firstByte |= hex[0] // first nibble is contained in the first byte + ni++ + } + for ; ni < hexLen; bi, ni = bi+1, ni+2 { + hex[bi] = hex[ni]<<4 | hex[ni+1] + } + hex[0] = firstByte + return binLen +} + // CompactToHex converts a compact encoded path to hex format func CompactToHex(compact []byte) []byte { + return compactToHex(compact) +} + +func compactToHex(compact []byte) []byte { if len(compact) == 0 { return compact } @@ -62,6 +135,20 @@ func decodeNibbles(nibbles []byte, bytes []byte) { } } +// prefixLen returns the length of the common prefix of a and b. +func prefixLen(a, b []byte) int { + var i, length = 0, len(a) + if len(b) < length { + length = len(b) + } + for ; i < length; i++ { + if a[i] != b[i] { + break + } + } + return i +} + // hasTerm returns whether a hex key has the terminator flag. func hasTerm(s []byte) bool { return len(s) > 0 && s[len(s)-1] == 16 diff --git a/trie_by_cid/trie/encoding_test.go b/trie_by_cid/trie/encoding_test.go new file mode 100644 index 0000000..abc1e9d --- /dev/null +++ b/trie_by_cid/trie/encoding_test.go @@ -0,0 +1,141 @@ +// Copyright 2014 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . 
+ +package trie + +import ( + "bytes" + crand "crypto/rand" + "encoding/hex" + "math/rand" + "testing" +) + +func TestHexCompact(t *testing.T) { + tests := []struct{ hex, compact []byte }{ + // empty keys, with and without terminator. + {hex: []byte{}, compact: []byte{0x00}}, + {hex: []byte{16}, compact: []byte{0x20}}, + // odd length, no terminator + {hex: []byte{1, 2, 3, 4, 5}, compact: []byte{0x11, 0x23, 0x45}}, + // even length, no terminator + {hex: []byte{0, 1, 2, 3, 4, 5}, compact: []byte{0x00, 0x01, 0x23, 0x45}}, + // odd length, terminator + {hex: []byte{15, 1, 12, 11, 8, 16 /*term*/}, compact: []byte{0x3f, 0x1c, 0xb8}}, + // even length, terminator + {hex: []byte{0, 15, 1, 12, 11, 8, 16 /*term*/}, compact: []byte{0x20, 0x0f, 0x1c, 0xb8}}, + } + for _, test := range tests { + if c := hexToCompact(test.hex); !bytes.Equal(c, test.compact) { + t.Errorf("hexToCompact(%x) -> %x, want %x", test.hex, c, test.compact) + } + if h := compactToHex(test.compact); !bytes.Equal(h, test.hex) { + t.Errorf("compactToHex(%x) -> %x, want %x", test.compact, h, test.hex) + } + } +} + +func TestHexKeybytes(t *testing.T) { + tests := []struct{ key, hexIn, hexOut []byte }{ + {key: []byte{}, hexIn: []byte{16}, hexOut: []byte{16}}, + {key: []byte{}, hexIn: []byte{}, hexOut: []byte{16}}, + { + key: []byte{0x12, 0x34, 0x56}, + hexIn: []byte{1, 2, 3, 4, 5, 6, 16}, + hexOut: []byte{1, 2, 3, 4, 5, 6, 16}, + }, + { + key: []byte{0x12, 0x34, 0x5}, + hexIn: []byte{1, 2, 3, 4, 0, 5, 16}, + hexOut: []byte{1, 2, 3, 4, 0, 5, 16}, + }, + { + key: []byte{0x12, 0x34, 0x56}, + hexIn: []byte{1, 2, 3, 4, 5, 6}, + hexOut: []byte{1, 2, 3, 4, 5, 6, 16}, + }, + } + for _, test := range tests { + if h := keybytesToHex(test.key); !bytes.Equal(h, test.hexOut) { + t.Errorf("keybytesToHex(%x) -> %x, want %x", test.key, h, test.hexOut) + } + if k := hexToKeyBytes(test.hexIn); !bytes.Equal(k, test.key) { + t.Errorf("hexToKeyBytes(%x) -> %x, want %x", test.hexIn, k, test.key) + } + } +} + +func 
TestHexToCompactInPlace(t *testing.T) { + for i, key := range []string{ + "00", + "060a040c0f000a090b040803010801010900080d090a0a0d0903000b10", + "10", + } { + hexBytes, _ := hex.DecodeString(key) + exp := hexToCompact(hexBytes) + sz := hexToCompactInPlace(hexBytes) + got := hexBytes[:sz] + if !bytes.Equal(exp, got) { + t.Fatalf("test %d: encoding err\ninp %v\ngot %x\nexp %x\n", i, key, got, exp) + } + } +} + +func TestHexToCompactInPlaceRandom(t *testing.T) { + for i := 0; i < 10000; i++ { + l := rand.Intn(128) + key := make([]byte, l) + crand.Read(key) + hexBytes := keybytesToHex(key) + hexOrig := []byte(string(hexBytes)) + exp := hexToCompact(hexBytes) + sz := hexToCompactInPlace(hexBytes) + got := hexBytes[:sz] + + if !bytes.Equal(exp, got) { + t.Fatalf("encoding err \ncpt %x\nhex %x\ngot %x\nexp %x\n", + key, hexOrig, got, exp) + } + } +} + +func BenchmarkHexToCompact(b *testing.B) { + testBytes := []byte{0, 15, 1, 12, 11, 8, 16 /*term*/} + for i := 0; i < b.N; i++ { + hexToCompact(testBytes) + } +} + +func BenchmarkCompactToHex(b *testing.B) { + testBytes := []byte{0, 15, 1, 12, 11, 8, 16 /*term*/} + for i := 0; i < b.N; i++ { + compactToHex(testBytes) + } +} + +func BenchmarkKeybytesToHex(b *testing.B) { + testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16} + for i := 0; i < b.N; i++ { + keybytesToHex(testBytes) + } +} + +func BenchmarkHexToKeybytes(b *testing.B) { + testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16} + for i := 0; i < b.N; i++ { + hexToKeyBytes(testBytes) + } +} diff --git a/trie_by_cid/trie/errors.go b/trie_by_cid/trie/errors.go index 3881487..afe344b 100644 --- a/trie_by_cid/trie/errors.go +++ b/trie_by_cid/trie/errors.go @@ -27,7 +27,7 @@ import ( // information necessary for retrieving the missing node. 
type MissingNodeError struct { Owner common.Hash // owner of the trie if it's 2-layered trie - NodeHash []byte // hash of the missing node + NodeHash common.Hash // hash of the missing node Path []byte // hex-encoded path to the missing node err error // concrete error for missing trie node } diff --git a/trie_by_cid/trie/hasher.go b/trie_by_cid/trie/hasher.go index caa80f0..e594d6d 100644 --- a/trie_by_cid/trie/hasher.go +++ b/trie_by_cid/trie/hasher.go @@ -21,7 +21,6 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/rlp" - "github.com/ethereum/go-ethereum/trie" "golang.org/x/crypto/sha3" ) @@ -99,7 +98,7 @@ func (h *hasher) hashShortNodeChildren(n *shortNode) (collapsed, cached *shortNo // Previously, we did copy this one. We don't seem to need to actually // do that, since we don't overwrite/reuse keys //cached.Key = common.CopyBytes(n.Key) - collapsed.Key = trie.HexToCompact(n.Key) + collapsed.Key = hexToCompact(n.Key) // Unless the child is a valuenode or hashnode, hash it switch n.Val.(type) { case *fullNode, *shortNode: @@ -171,8 +170,8 @@ func (h *hasher) fullnodeToHash(n *fullNode, force bool) node { // // All node encoding must be done like this: // -// node.encode(h.encbuf) -// enc := h.encodedBytes() +// node.encode(h.encbuf) +// enc := h.encodedBytes() // // This convention exists because node.encode can only be inlined/escape-analyzed when // called on a concrete receiver type. 
diff --git a/trie_by_cid/trie/iterator.go b/trie_by_cid/trie/iterator.go index 55e7a96..de506eb 100644 --- a/trie_by_cid/trie/iterator.go +++ b/trie_by_cid/trie/iterator.go @@ -18,14 +18,26 @@ package trie import ( "bytes" + "container/heap" "errors" + "time" + + "github.com/ethereum/go-ethereum/statediff/indexer/database/metrics" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/trie" + "github.com/ethereum/go-ethereum/core/types" + gethtrie "github.com/ethereum/go-ethereum/trie" ) -// NodeIterator is a re-export of the go-ethereum interface -type NodeIterator = trie.NodeIterator +// NodeIterator is an iterator to traverse the trie pre-order. +type NodeIterator = gethtrie.NodeIterator + +// NodeResolver is used for looking up trie nodes before reaching into the real +// persistent layer. This is not mandatory, rather is an optimization for cases +// where trie nodes can be recovered from some external mechanism without reading +// from disk. In those cases, this resolver allows short circuiting accesses and +// returning them from memory. +type NodeResolver = gethtrie.NodeResolver // Iterator is a key-value trie iterator that traverses a Trie. type Iterator struct { @@ -82,7 +94,7 @@ type nodeIterator struct { path []byte // Path to the current node err error // Failure set in case of an internal error in the iterator - resolver trie.NodeResolver // Optional intermediate resolver above the disk layer + resolver NodeResolver // optional node resolver for avoiding disk hits } // errIteratorEnd is stored in nodeIterator.err when iteration is done. 
@@ -99,7 +111,7 @@ func (e seekError) Error() string { } func newNodeIterator(trie *Trie, start []byte) NodeIterator { - if trie.Hash() == emptyRoot { + if trie.Hash() == types.EmptyRootHash { return &nodeIterator{ trie: trie, err: errIteratorEnd, @@ -110,7 +122,7 @@ func newNodeIterator(trie *Trie, start []byte) NodeIterator { return it } -func (it *nodeIterator) AddResolver(resolver trie.NodeResolver) { +func (it *nodeIterator) AddResolver(resolver NodeResolver) { it.resolver = resolver } @@ -128,6 +140,14 @@ func (it *nodeIterator) Parent() common.Hash { return it.stack[len(it.stack)-1].parent } +func (it *nodeIterator) ParentPath() []byte { + if len(it.stack) == 0 { + return []byte{} + } + pathlen := it.stack[len(it.stack)-1].pathlen + return it.path[:pathlen] +} + func (it *nodeIterator) Leaf() bool { return hasTerm(it.path) } @@ -241,7 +261,7 @@ func (it *nodeIterator) seek(prefix []byte) error { func (it *nodeIterator) init() (*nodeIteratorState, error) { root := it.trie.Hash() state := &nodeIteratorState{node: it.trie.root, index: -1} - if root != emptyRoot { + if root != types.EmptyRootHash { state.hash = root } return state, state.resolve(it, nil) @@ -320,7 +340,12 @@ func (it *nodeIterator) resolveHash(hash hashNode, path []byte) (node, error) { } } } - return it.trie.resolveHash(hash, path) + // Retrieve the specified node from the underlying node reader. + // it.trie.resolveAndTrack is not used since in that function the + // loaded blob will be tracked, while it's not required here since + // all loaded nodes won't be linked to trie at all and track nodes + // may lead to out-of-memory issue. 
+ return it.trie.reader.node(path, common.BytesToHash(hash)) } func (it *nodeIterator) resolveBlob(hash hashNode, path []byte) ([]byte, error) { @@ -329,7 +354,12 @@ func (it *nodeIterator) resolveBlob(hash hashNode, path []byte) ([]byte, error) return blob, nil } } - return it.trie.resolveBlob(hash, path) + // Retrieve the specified node from the underlying node reader. + // it.trie.resolveAndTrack is not used since in that function the + // loaded blob will be tracked, while it's not required here since + // all loaded nodes won't be linked to trie at all and track nodes + // may lead to out-of-memory issue. + return it.trie.reader.nodeBlob(path, common.BytesToHash(hash)) } func (st *nodeIteratorState) resolve(it *nodeIterator, path []byte) error { @@ -455,3 +485,248 @@ func (it *nodeIterator) pop() { it.stack[len(it.stack)-1] = nil it.stack = it.stack[:len(it.stack)-1] } + +func compareNodes(a, b NodeIterator) int { + if cmp := bytes.Compare(a.Path(), b.Path()); cmp != 0 { + return cmp + } + if a.Leaf() && !b.Leaf() { + return -1 + } else if b.Leaf() && !a.Leaf() { + return 1 + } + if cmp := bytes.Compare(a.Hash().Bytes(), b.Hash().Bytes()); cmp != 0 { + return cmp + } + if a.Leaf() && b.Leaf() { + return bytes.Compare(a.LeafBlob(), b.LeafBlob()) + } + return 0 +} + +type differenceIterator struct { + a, b NodeIterator // Nodes returned are those in b - a. + eof bool // Indicates a has run out of elements + count int // Number of nodes scanned on either trie +} + +// NewDifferenceIterator constructs a NodeIterator that iterates over elements in b that +// are not in a. Returns the iterator, and a pointer to an integer recording the number +// of nodes seen. 
+func NewDifferenceIterator(a, b NodeIterator) (NodeIterator, *int) { + a.Next(true) + it := &differenceIterator{ + a: a, + b: b, + } + return it, &it.count +} + +func (it *differenceIterator) Hash() common.Hash { + return it.b.Hash() +} + +func (it *differenceIterator) Parent() common.Hash { + return it.b.Parent() +} + +func (it *differenceIterator) ParentPath() []byte { + return it.b.ParentPath() +} + +func (it *differenceIterator) Leaf() bool { + return it.b.Leaf() +} + +func (it *differenceIterator) LeafKey() []byte { + return it.b.LeafKey() +} + +func (it *differenceIterator) LeafBlob() []byte { + return it.b.LeafBlob() +} + +func (it *differenceIterator) LeafProof() [][]byte { + return it.b.LeafProof() +} + +func (it *differenceIterator) Path() []byte { + return it.b.Path() +} + +func (it *differenceIterator) NodeBlob() []byte { + return it.b.NodeBlob() +} + +func (it *differenceIterator) AddResolver(resolver NodeResolver) { + panic("not implemented") +} + +func (it *differenceIterator) Next(bool) bool { + defer metrics.UpdateDuration(time.Now(), metrics.IndexerMetrics.DifferenceIteratorNextTimer) + // Invariants: + // - We always advance at least one element in b. + // - At the start of this function, a's path is lexically greater than b's. 
+ if !it.b.Next(true) { + return false + } + it.count++ + + if it.eof { + // a has reached eof, so we just return all elements from b + return true + } + + for { + switch compareNodes(it.a, it.b) { + case -1: + // b jumped past a; advance a + if !it.a.Next(true) { + it.eof = true + return true + } + it.count++ + case 1: + // b is before a + return true + case 0: + // a and b are identical; skip this whole subtree if the nodes have hashes + hasHash := it.a.Hash() == common.Hash{} + if !it.b.Next(hasHash) { + return false + } + it.count++ + if !it.a.Next(hasHash) { + it.eof = true + return true + } + it.count++ + } + } +} + +func (it *differenceIterator) Error() error { + if err := it.a.Error(); err != nil { + return err + } + return it.b.Error() +} + +type nodeIteratorHeap []NodeIterator + +func (h nodeIteratorHeap) Len() int { return len(h) } +func (h nodeIteratorHeap) Less(i, j int) bool { return compareNodes(h[i], h[j]) < 0 } +func (h nodeIteratorHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } +func (h *nodeIteratorHeap) Push(x interface{}) { *h = append(*h, x.(NodeIterator)) } +func (h *nodeIteratorHeap) Pop() interface{} { + n := len(*h) + x := (*h)[n-1] + *h = (*h)[0 : n-1] + return x +} + +type unionIterator struct { + items *nodeIteratorHeap // Nodes returned are the union of the ones in these iterators + count int // Number of nodes scanned across all tries +} + +// NewUnionIterator constructs a NodeIterator that iterates over elements in the union +// of the provided NodeIterators. Returns the iterator, and a pointer to an integer +// recording the number of nodes visited. 
+func NewUnionIterator(iters []NodeIterator) (NodeIterator, *int) { + h := make(nodeIteratorHeap, len(iters)) + copy(h, iters) + heap.Init(&h) + + ui := &unionIterator{items: &h} + return ui, &ui.count +} + +func (it *unionIterator) Hash() common.Hash { + return (*it.items)[0].Hash() +} + +func (it *unionIterator) Parent() common.Hash { + return (*it.items)[0].Parent() +} + +func (it *unionIterator) ParentPath() []byte { + return (*it.items)[0].ParentPath() +} + +func (it *unionIterator) Leaf() bool { + return (*it.items)[0].Leaf() +} + +func (it *unionIterator) LeafKey() []byte { + return (*it.items)[0].LeafKey() +} + +func (it *unionIterator) LeafBlob() []byte { + return (*it.items)[0].LeafBlob() +} + +func (it *unionIterator) LeafProof() [][]byte { + return (*it.items)[0].LeafProof() +} + +func (it *unionIterator) Path() []byte { + return (*it.items)[0].Path() +} + +func (it *unionIterator) NodeBlob() []byte { + return (*it.items)[0].NodeBlob() +} + +func (it *unionIterator) AddResolver(resolver NodeResolver) { + panic("not implemented") +} + +// Next returns the next node in the union of tries being iterated over. +// +// It does this by maintaining a heap of iterators, sorted by the iteration +// order of their next elements, with one entry for each source trie. Each +// time Next() is called, it takes the least element from the heap to return, +// advancing any other iterators that also point to that same element. These +// iterators are called with descend=false, since we know that any nodes under +// these nodes will also be duplicates, found in the currently selected iterator. +// Whenever an iterator is advanced, it is pushed back into the heap if it still +// has elements remaining. +// +// In the case that descend=false - eg, we're asked to ignore all subnodes of the +// current node - we also advance any iterators in the heap that have the current +// path as a prefix. 
+func (it *unionIterator) Next(descend bool) bool { + if len(*it.items) == 0 { + return false + } + + // Get the next key from the union + least := heap.Pop(it.items).(NodeIterator) + + // Skip over other nodes as long as they're identical, or, if we're not descending, as + // long as they have the same prefix as the current node. + for len(*it.items) > 0 && ((!descend && bytes.HasPrefix((*it.items)[0].Path(), least.Path())) || compareNodes(least, (*it.items)[0]) == 0) { + skipped := heap.Pop(it.items).(NodeIterator) + // Skip the whole subtree if the nodes have hashes; otherwise just skip this node + if skipped.Next(skipped.Hash() == common.Hash{}) { + it.count++ + // If there are more elements, push the iterator back on the heap + heap.Push(it.items, skipped) + } + } + if least.Next(descend) { + it.count++ + heap.Push(it.items, least) + } + return len(*it.items) > 0 +} + +func (it *unionIterator) Error() error { + for i := 0; i < len(*it.items); i++ { + if err := (*it.items)[i].Error(); err != nil { + return err + } + } + return nil +} diff --git a/trie_by_cid/trie/iterator_test.go b/trie_by_cid/trie/iterator_test.go index f7f0103..8b606ef 100644 --- a/trie_by_cid/trie/iterator_test.go +++ b/trie_by_cid/trie/iterator_test.go @@ -14,28 +14,24 @@ // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . 
-package trie_test +package trie import ( "bytes" - "context" + "encoding/binary" "fmt" + "math/rand" "testing" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" - "github.com/ethereum/go-ethereum/statediff/indexer/database/sql/postgres" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/ethdb/memorydb" geth_trie "github.com/ethereum/go-ethereum/trie" - - "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie" ) var ( - dbConfig, _ = postgres.DefaultConfig.WithEnv() - trieConfig = trie.Config{Cache: 256} - ctx = context.Background() - - testdata0 = []kvs{ + packableTestData = []kvsi{ {"one", 1}, {"two", 2}, {"three", 3}, @@ -43,20 +39,10 @@ var ( {"five", 5}, {"ten", 10}, } - testdata1 = []kvs{ - {"barb", 0}, - {"bard", 1}, - {"bars", 2}, - {"bar", 3}, - {"fab", 4}, - {"food", 5}, - {"foos", 6}, - {"foo", 7}, - } ) func TestEmptyIterator(t *testing.T) { - trie := trie.NewEmpty(trie.NewDatabase(rawdb.NewMemoryDatabase())) + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) iter := trie.NodeIterator(nil) seen := make(map[string]struct{}) @@ -69,63 +55,163 @@ func TestEmptyIterator(t *testing.T) { } func TestIterator(t *testing.T) { - edb := rawdb.NewMemoryDatabase() - db := geth_trie.NewDatabase(edb) - origtrie := geth_trie.NewEmpty(db) - all, err := updateTrie(origtrie, testdata0) - if err != nil { - t.Fatal(err) + db := NewDatabase(rawdb.NewMemoryDatabase()) + trie := NewEmpty(db) + vals := []struct{ k, v string }{ + {"do", "verb"}, + {"ether", "wookiedoo"}, + {"horse", "stallion"}, + {"shaman", "horse"}, + {"doge", "coin"}, + {"dog", "puppy"}, + {"somethingveryoddindeedthis is", "myothernodedata"}, } - // commit and index data - root := commitTrie(t, db, origtrie) - tr := indexTrie(t, edb, root) + all := make(map[string]string) + for _, val := range vals { + all[val.k] = val.v + trie.Update([]byte(val.k), []byte(val.v)) + } + root, nodes := trie.Commit(false) + db.Update(NewWithNodeSet(nodes)) - 
found := make(map[string]int64) - it := trie.NewIterator(tr.NodeIterator(nil)) + trie, _ = New(TrieID(root), db, StateTrieCodec) + found := make(map[string]string) + it := NewIterator(trie.NodeIterator(nil)) for it.Next() { - found[string(it.Key)] = unpackValue(it.Value) + found[string(it.Key)] = string(it.Value) } - if len(found) != len(all) { - t.Errorf("number of iterated values do not match: want %d, found %d", len(all), len(found)) - } - for k, kv := range all { - if found[k] != kv.v { - t.Errorf("iterator value mismatch for %s: got %q want %q", k, found[k], kv.v) + for k, v := range all { + if found[k] != v { + t.Errorf("iterator value mismatch for %s: got %q want %q", k, found[k], v) } } } -func TestIteratorSeek(t *testing.T) { - edb := rawdb.NewMemoryDatabase() - db := geth_trie.NewDatabase(edb) - orig := geth_trie.NewEmpty(geth_trie.NewDatabase(rawdb.NewMemoryDatabase())) - if _, err := updateTrie(orig, testdata1); err != nil { - t.Fatal(err) +type kv struct { + k, v []byte + t bool +} + +func TestIteratorLargeData(t *testing.T) { + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + vals := make(map[string]*kv) + + for i := byte(0); i < 255; i++ { + value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} + value2 := &kv{common.LeftPadBytes([]byte{10, i}, 32), []byte{i}, false} + trie.Update(value.k, value.v) + trie.Update(value2.k, value2.v) + vals[string(value.k)] = value + vals[string(value2.k)] = value2 + } + + it := NewIterator(trie.NodeIterator(nil)) + for it.Next() { + vals[string(it.Key)].t = true + } + + var untouched []*kv + for _, value := range vals { + if !value.t { + untouched = append(untouched, value) + } + } + + if len(untouched) > 0 { + t.Errorf("Missed %d nodes", len(untouched)) + for _, value := range untouched { + t.Error(value) + } + } +} + +// Tests that the node iterator indeed walks over the entire database contents. 
+func TestNodeIteratorCoverage(t *testing.T) { + db, trie, _ := makeTestTrie(t) + // Create some arbitrary test trie to iterate + + // Gather all the node hashes found by the iterator + hashes := make(map[common.Hash]struct{}) + for it := trie.NodeIterator(nil); it.Next(true); { + if it.Hash() != (common.Hash{}) { + hashes[it.Hash()] = struct{}{} + } + } + // Cross check the hashes and the database itself + for hash := range hashes { + if _, err := db.Node(hash, StateTrieCodec); err != nil { + t.Errorf("failed to retrieve reported node %x: %v", hash, err) + } + } + for hash, obj := range db.dirties { + if obj != nil && hash != (common.Hash{}) { + if _, ok := hashes[hash]; !ok { + t.Errorf("state entry not reported %x", hash) + } + } + } + it := db.diskdb.NewIterator(nil, nil) + for it.Next() { + key := it.Key() + if _, ok := hashes[common.BytesToHash(key)]; !ok { + t.Errorf("state entry not reported %x", key) + } + } + it.Release() +} + +type kvs struct{ k, v string } + +var testdata1 = []kvs{ + {"barb", "ba"}, + {"bard", "bc"}, + {"bars", "bb"}, + {"bar", "b"}, + {"fab", "z"}, + {"food", "ab"}, + {"foos", "aa"}, + {"foo", "a"}, +} + +var testdata2 = []kvs{ + {"aardvark", "c"}, + {"bar", "b"}, + {"barb", "bd"}, + {"bars", "be"}, + {"fab", "z"}, + {"foo", "a"}, + {"foos", "aa"}, + {"food", "ab"}, + {"jars", "d"}, +} + +func TestIteratorSeek(t *testing.T) { + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + for _, val := range testdata1 { + trie.Update([]byte(val.k), []byte(val.v)) } - root := commitTrie(t, db, orig) - tr := indexTrie(t, edb, root) // Seek to the middle. - it := trie.NewIterator(tr.NodeIterator([]byte("fab"))) + it := NewIterator(trie.NodeIterator([]byte("fab"))) if err := checkIteratorOrder(testdata1[4:], it); err != nil { t.Fatal(err) } // Seek to a non-existent key. 
- it = trie.NewIterator(tr.NodeIterator([]byte("barc"))) + it = NewIterator(trie.NodeIterator([]byte("barc"))) if err := checkIteratorOrder(testdata1[1:], it); err != nil { t.Fatal(err) } // Seek beyond the end. - it = trie.NewIterator(tr.NodeIterator([]byte("z"))) + it = NewIterator(trie.NodeIterator([]byte("z"))) if err := checkIteratorOrder(nil, it); err != nil { t.Fatal(err) } } -func checkIteratorOrder(want []kvs, it *trie.Iterator) error { +func checkIteratorOrder(want []kvs, it *Iterator) error { for it.Next() { if len(want) == 0 { return fmt.Errorf("didn't expect any more values, got key %q", it.Key) @@ -141,11 +227,366 @@ func checkIteratorOrder(want []kvs, it *trie.Iterator) error { return nil } +func TestDifferenceIterator(t *testing.T) { + dba := NewDatabase(rawdb.NewMemoryDatabase()) + triea := NewEmpty(dba) + for _, val := range testdata1 { + triea.Update([]byte(val.k), []byte(val.v)) + } + rootA, nodesA := triea.Commit(false) + dba.Update(NewWithNodeSet(nodesA)) + triea, _ = New(TrieID(rootA), dba, StateTrieCodec) + + dbb := NewDatabase(rawdb.NewMemoryDatabase()) + trieb := NewEmpty(dbb) + for _, val := range testdata2 { + trieb.Update([]byte(val.k), []byte(val.v)) + } + rootB, nodesB := trieb.Commit(false) + dbb.Update(NewWithNodeSet(nodesB)) + trieb, _ = New(TrieID(rootB), dbb, StateTrieCodec) + + found := make(map[string]string) + di, _ := NewDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil)) + it := NewIterator(di) + for it.Next() { + found[string(it.Key)] = string(it.Value) + } + + all := []struct{ k, v string }{ + {"aardvark", "c"}, + {"barb", "bd"}, + {"bars", "be"}, + {"jars", "d"}, + } + for _, item := range all { + if found[item.k] != item.v { + t.Errorf("iterator value mismatch for %s: got %v want %v", item.k, found[item.k], item.v) + } + } + if len(found) != len(all) { + t.Errorf("iterator count mismatch: got %d values, want %d", len(found), len(all)) + } +} + +func TestUnionIterator(t *testing.T) { + dba := 
NewDatabase(rawdb.NewMemoryDatabase()) + triea := NewEmpty(dba) + for _, val := range testdata1 { + triea.Update([]byte(val.k), []byte(val.v)) + } + rootA, nodesA := triea.Commit(false) + dba.Update(NewWithNodeSet(nodesA)) + triea, _ = New(TrieID(rootA), dba, StateTrieCodec) + + dbb := NewDatabase(rawdb.NewMemoryDatabase()) + trieb := NewEmpty(dbb) + for _, val := range testdata2 { + trieb.Update([]byte(val.k), []byte(val.v)) + } + rootB, nodesB := trieb.Commit(false) + dbb.Update(NewWithNodeSet(nodesB)) + trieb, _ = New(TrieID(rootB), dbb, StateTrieCodec) + + di, _ := NewUnionIterator([]NodeIterator{triea.NodeIterator(nil), trieb.NodeIterator(nil)}) + it := NewIterator(di) + + all := []struct{ k, v string }{ + {"aardvark", "c"}, + {"barb", "ba"}, + {"barb", "bd"}, + {"bard", "bc"}, + {"bars", "bb"}, + {"bars", "be"}, + {"bar", "b"}, + {"fab", "z"}, + {"food", "ab"}, + {"foos", "aa"}, + {"foo", "a"}, + {"jars", "d"}, + } + + for i, kv := range all { + if !it.Next() { + t.Errorf("Iterator ends prematurely at element %d", i) + } + if kv.k != string(it.Key) { + t.Errorf("iterator value mismatch for element %d: got key %s want %s", i, it.Key, kv.k) + } + if kv.v != string(it.Value) { + t.Errorf("iterator value mismatch for element %d: got value %s want %s", i, it.Value, kv.v) + } + } + if it.Next() { + t.Errorf("Iterator returned extra values.") + } +} + +func TestIteratorNoDups(t *testing.T) { + tr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + for _, val := range testdata1 { + tr.Update([]byte(val.k), []byte(val.v)) + } + checkIteratorNoDups(t, tr.NodeIterator(nil), nil) +} + +// This test checks that nodeIterator.Next can be retried after inserting missing trie nodes. 
+func TestIteratorContinueAfterErrorMemonly(t *testing.T) { testIteratorContinueAfterError(t, true) } + +func testIteratorContinueAfterError(t *testing.T, memonly bool) { + diskdb := rawdb.NewMemoryDatabase() + triedb := NewDatabase(diskdb) + + tr := NewEmpty(triedb) + for _, val := range testdata1 { + tr.Update([]byte(val.k), []byte(val.v)) + } + _, nodes := tr.Commit(false) + triedb.Update(NewWithNodeSet(nodes)) + // if !memonly { + // triedb.Commit(tr.Hash(), false) + // } + wantNodeCount := checkIteratorNoDups(t, tr.NodeIterator(nil), nil) + + var ( + diskKeys [][]byte + memKeys []common.Hash + ) + if memonly { + memKeys = triedb.Nodes() + } else { + it := diskdb.NewIterator(nil, nil) + for it.Next() { + diskKeys = append(diskKeys, it.Key()) + } + it.Release() + } + for i := 0; i < 20; i++ { + // Create trie that will load all nodes from DB. + tr, _ := New(TrieID(tr.Hash()), triedb, StateTrieCodec) + + // Remove a random node from the database. It can't be the root node + // because that one is already loaded. + var ( + rkey common.Hash + rval []byte + robj *cachedNode + ) + for { + if memonly { + rkey = memKeys[rand.Intn(len(memKeys))] + } else { + copy(rkey[:], diskKeys[rand.Intn(len(diskKeys))]) + } + if rkey != tr.Hash() { + break + } + } + if memonly { + robj = triedb.dirties[rkey] + delete(triedb.dirties, rkey) + } else { + rval, _ = diskdb.Get(rkey[:]) + diskdb.Delete(rkey[:]) + } + // Iterate until the error is hit. + seen := make(map[string]bool) + it := tr.NodeIterator(nil) + checkIteratorNoDups(t, it, seen) + missing, ok := it.Error().(*MissingNodeError) + if !ok || missing.NodeHash != rkey { + t.Fatal("didn't hit missing node, got", it.Error()) + } + + // Add the node back and continue iteration. 
+ if memonly { + triedb.dirties[rkey] = robj + } else { + diskdb.Put(rkey[:], rval) + } + checkIteratorNoDups(t, it, seen) + if it.Error() != nil { + t.Fatal("unexpected error", it.Error()) + } + if len(seen) != wantNodeCount { + t.Fatal("wrong node iteration count, got", len(seen), "want", wantNodeCount) + } + } +} + +// Similar to the test above, this one checks that failure to create nodeIterator at a +// certain key prefix behaves correctly when Next is called. The expectation is that Next +// should retry seeking before returning true for the first time. +func TestIteratorContinueAfterSeekErrorMemonly(t *testing.T) { + testIteratorContinueAfterSeekError(t, true) +} + +func testIteratorContinueAfterSeekError(t *testing.T, memonly bool) { + // Commit test trie to db, then remove the node containing "bars". + diskdb := rawdb.NewMemoryDatabase() + triedb := NewDatabase(diskdb) + + ctr := NewEmpty(triedb) + for _, val := range testdata1 { + ctr.Update([]byte(val.k), []byte(val.v)) + } + root, nodes := ctr.Commit(false) + triedb.Update(NewWithNodeSet(nodes)) + // if !memonly { + // triedb.Commit(root, false) + // } + barNodeHash := common.HexToHash("05041990364eb72fcb1127652ce40d8bab765f2bfe53225b1170d276cc101c2e") + var ( + barNodeBlob []byte + barNodeObj *cachedNode + ) + if memonly { + barNodeObj = triedb.dirties[barNodeHash] + delete(triedb.dirties, barNodeHash) + } else { + barNodeBlob, _ = diskdb.Get(barNodeHash[:]) + diskdb.Delete(barNodeHash[:]) + } + // Create a new iterator that seeks to "bars". Seeking can't proceed because + // the node is missing. + tr, _ := New(TrieID(root), triedb, StateTrieCodec) + it := tr.NodeIterator([]byte("bars")) + missing, ok := it.Error().(*MissingNodeError) + if !ok { + t.Fatal("want MissingNodeError, got", it.Error()) + } else if missing.NodeHash != barNodeHash { + t.Fatal("wrong node missing") + } + // Reinsert the missing node. 
+ if memonly { + triedb.dirties[barNodeHash] = barNodeObj + } else { + diskdb.Put(barNodeHash[:], barNodeBlob) + } + // Check that iteration produces the right set of values. + if err := checkIteratorOrder(testdata1[2:], NewIterator(it)); err != nil { + t.Fatal(err) + } +} + +func checkIteratorNoDups(t *testing.T, it NodeIterator, seen map[string]bool) int { + if seen == nil { + seen = make(map[string]bool) + } + for it.Next(true) { + if seen[string(it.Path())] { + t.Fatalf("iterator visited node path %x twice", it.Path()) + } + seen[string(it.Path())] = true + } + return len(seen) +} + +type loggingTrieDb struct { + *Database + getCount uint64 +} + +// GetReader retrieves a node reader belonging to the given state root. +func (db *loggingTrieDb) GetReader(root common.Hash, codec uint64) Reader { + return &loggingNodeReader{db, codec} +} + +// hashReader is reader of hashDatabase which implements the Reader interface. +type loggingNodeReader struct { + db *loggingTrieDb + codec uint64 +} + +// Node retrieves the trie node with the given node hash. +func (reader *loggingNodeReader) Node(owner common.Hash, path []byte, hash common.Hash) (node, error) { + blob, err := reader.NodeBlob(owner, path, hash) + if err != nil { + return nil, err + } + return decodeNodeUnsafe(hash[:], blob) +} + +// NodeBlob retrieves the RLP-encoded trie node blob with the given node hash. 
+func (reader *loggingNodeReader) NodeBlob(_ common.Hash, _ []byte, hash common.Hash) ([]byte, error) { + reader.db.getCount++ + return reader.db.Node(hash, reader.codec) +} + +func newLoggingStateTrie(id *ID, db *Database, codec uint64) (*StateTrie, *loggingTrieDb, error) { + logdb := &loggingTrieDb{Database: db} + trie, err := New(id, logdb, codec) + if err != nil { + return nil, nil, err + } + return &StateTrie{trie: *trie, preimages: db.preimages}, logdb, nil +} + +// makeLargeTestTrie create a sample test trie +func makeLargeTestTrie(t testing.TB) (*Database, *StateTrie, *loggingTrieDb) { + // Create an empty trie + triedb := NewDatabase(rawdb.NewDatabase(memorydb.New())) + trie, logDb, err := newLoggingStateTrie(TrieID(common.Hash{}), triedb, StateTrieCodec) + if err != nil { + t.Fatal(err) + } + + // Fill it with some arbitrary data + for i := 0; i < 10000; i++ { + key := make([]byte, 32) + val := make([]byte, 32) + binary.BigEndian.PutUint64(key, uint64(i)) + binary.BigEndian.PutUint64(val, uint64(i)) + key = crypto.Keccak256(key) + val = crypto.Keccak256(val) + trie.Update(key, val) + } + _, nodes := trie.Commit(false) + triedb.Update(NewWithNodeSet(nodes)) + // Return the generated trie + return triedb, trie, logDb +} + +// Tests that the node iterator indeed walks over the entire database contents. 
+func TestNodeIteratorLargeTrie(t *testing.T) { + // Create some arbitrary test trie to iterate + _, trie, logDb := makeLargeTestTrie(t) + // Do a seek operation + trie.NodeIterator(common.FromHex("0x77667766776677766778855885885885")) + // master: 24 get operations + // this pr: 5 get operations + if have, want := logDb.getCount, uint64(5); have != want { + t.Fatalf("Wrong number of lookups during seek, have %d want %d", have, want) + } +} + func TestIteratorNodeBlob(t *testing.T) { + // var ( + // db = rawdb.NewMemoryDatabase() + // triedb = NewDatabase(db) + // trie = NewEmpty(triedb) + // ) + // vals := []struct{ k, v string }{ + // {"do", "verb"}, + // {"ether", "wookiedoo"}, + // {"horse", "stallion"}, + // {"shaman", "horse"}, + // {"doge", "coin"}, + // {"dog", "puppy"}, + // {"somethingveryoddindeedthis is", "myothernodedata"}, + // } + // all := make(map[string]string) + // for _, val := range vals { + // all[val.k] = val.v + // trie.Update([]byte(val.k), []byte(val.v)) + // } + // _, nodes := trie.Commit(false) + // triedb.Update(NewWithNodeSet(nodes)) + edb := rawdb.NewMemoryDatabase() db := geth_trie.NewDatabase(edb) orig := geth_trie.NewEmpty(geth_trie.NewDatabase(rawdb.NewMemoryDatabase())) - if _, err := updateTrie(orig, testdata1); err != nil { + if _, err := updateTrie(orig, packableTestData); err != nil { t.Fatal(err) } root := commitTrie(t, db, orig) @@ -167,7 +608,7 @@ func TestIteratorNodeBlob(t *testing.T) { for dbIter.Next() { got, present := found[common.BytesToHash(dbIter.Key())] if !present { - t.Fatalf("Missing trie node %v", dbIter.Key()) + t.Fatalf("Miss trie node %v", dbIter.Key()) } if !bytes.Equal(got, dbIter.Value()) { t.Fatalf("Unexpected trie node want %v got %v", dbIter.Value(), got) @@ -175,6 +616,6 @@ func TestIteratorNodeBlob(t *testing.T) { count += 1 } if count != len(found) { - t.Fatal("Find extra trie node via iterator") + t.Fatalf("Wrong number of trie nodes found, want %d, got %d", len(found), count) } } diff --git 
a/trie_by_cid/trie/node.go b/trie_by_cid/trie/node.go index c995099..6ce6551 100644 --- a/trie_by_cid/trie/node.go +++ b/trie_by_cid/trie/node.go @@ -23,7 +23,6 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/rlp" - "github.com/ethereum/go-ethereum/trie" ) var indices = []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b", "c", "d", "e", "f", "[17]"} @@ -157,7 +156,7 @@ func decodeShort(hash, elems []byte) (node, error) { return nil, err } flag := nodeFlag{hash: hash} - key := trie.CompactToHex(kbuf) + key := compactToHex(kbuf) if hasTerm(key) { // value node val, _, err := rlp.SplitString(rest) diff --git a/trie_by_cid/trie/node_enc.go b/trie_by_cid/trie/node_enc.go index 2d26350..cade35b 100644 --- a/trie_by_cid/trie/node_enc.go +++ b/trie_by_cid/trie/node_enc.go @@ -58,3 +58,30 @@ func (n hashNode) encode(w rlp.EncoderBuffer) { func (n valueNode) encode(w rlp.EncoderBuffer) { w.WriteBytes(n) } + +func (n rawFullNode) encode(w rlp.EncoderBuffer) { + offset := w.List() + for _, c := range n { + if c != nil { + c.encode(w) + } else { + w.Write(rlp.EmptyString) + } + } + w.ListEnd(offset) +} + +func (n *rawShortNode) encode(w rlp.EncoderBuffer) { + offset := w.List() + w.WriteBytes(n.Key) + if n.Val != nil { + n.Val.encode(w) + } else { + w.Write(rlp.EmptyString) + } + w.ListEnd(offset) +} + +func (n rawNode) encode(w rlp.EncoderBuffer) { + w.Write(n) +} diff --git a/trie_by_cid/trie/node_test.go b/trie_by_cid/trie/node_test.go index ac1d8fb..9b8b337 100644 --- a/trie_by_cid/trie/node_test.go +++ b/trie_by_cid/trie/node_test.go @@ -20,6 +20,7 @@ import ( "bytes" "testing" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/rlp" ) @@ -92,3 +93,123 @@ func TestDecodeFullNode(t *testing.T) { t.Fatalf("decode full node err: %v", err) } } + +// goos: darwin +// goarch: arm64 +// pkg: github.com/ethereum/go-ethereum/trie +// BenchmarkEncodeShortNode +// BenchmarkEncodeShortNode-8 16878850 70.81 
ns/op 48 B/op 1 allocs/op +func BenchmarkEncodeShortNode(b *testing.B) { + node := &shortNode{ + Key: []byte{0x1, 0x2}, + Val: hashNode(randBytes(32)), + } + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + nodeToBytes(node) + } +} + +// goos: darwin +// goarch: arm64 +// pkg: github.com/ethereum/go-ethereum/trie +// BenchmarkEncodeFullNode +// BenchmarkEncodeFullNode-8 4323273 284.4 ns/op 576 B/op 1 allocs/op +func BenchmarkEncodeFullNode(b *testing.B) { + node := &fullNode{} + for i := 0; i < 16; i++ { + node.Children[i] = hashNode(randBytes(32)) + } + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + nodeToBytes(node) + } +} + +// goos: darwin +// goarch: arm64 +// pkg: github.com/ethereum/go-ethereum/trie +// BenchmarkDecodeShortNode +// BenchmarkDecodeShortNode-8 7925638 151.0 ns/op 157 B/op 4 allocs/op +func BenchmarkDecodeShortNode(b *testing.B) { + node := &shortNode{ + Key: []byte{0x1, 0x2}, + Val: hashNode(randBytes(32)), + } + blob := nodeToBytes(node) + hash := crypto.Keccak256(blob) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + mustDecodeNode(hash, blob) + } +} + +// goos: darwin +// goarch: arm64 +// pkg: github.com/ethereum/go-ethereum/trie +// BenchmarkDecodeShortNodeUnsafe +// BenchmarkDecodeShortNodeUnsafe-8 9027476 128.6 ns/op 109 B/op 3 allocs/op +func BenchmarkDecodeShortNodeUnsafe(b *testing.B) { + node := &shortNode{ + Key: []byte{0x1, 0x2}, + Val: hashNode(randBytes(32)), + } + blob := nodeToBytes(node) + hash := crypto.Keccak256(blob) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + mustDecodeNodeUnsafe(hash, blob) + } +} + +// goos: darwin +// goarch: arm64 +// pkg: github.com/ethereum/go-ethereum/trie +// BenchmarkDecodeFullNode +// BenchmarkDecodeFullNode-8 1597462 761.9 ns/op 1280 B/op 18 allocs/op +func BenchmarkDecodeFullNode(b *testing.B) { + node := &fullNode{} + for i := 0; i < 16; i++ { + node.Children[i] = hashNode(randBytes(32)) + } + blob := 
nodeToBytes(node) + hash := crypto.Keccak256(blob) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + mustDecodeNode(hash, blob) + } +} + +// goos: darwin +// goarch: arm64 +// pkg: github.com/ethereum/go-ethereum/trie +// BenchmarkDecodeFullNodeUnsafe +// BenchmarkDecodeFullNodeUnsafe-8 1789070 687.1 ns/op 704 B/op 17 allocs/op +func BenchmarkDecodeFullNodeUnsafe(b *testing.B) { + node := &fullNode{} + for i := 0; i < 16; i++ { + node.Children[i] = hashNode(randBytes(32)) + } + blob := nodeToBytes(node) + hash := crypto.Keccak256(blob) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + mustDecodeNodeUnsafe(hash, blob) + } +} diff --git a/trie_by_cid/trie/nodeset.go b/trie_by_cid/trie/nodeset.go new file mode 100644 index 0000000..99e4a80 --- /dev/null +++ b/trie_by_cid/trie/nodeset.go @@ -0,0 +1,218 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package trie + +import ( + "fmt" + "reflect" + "sort" + "strings" + + "github.com/ethereum/go-ethereum/common" +) + +// memoryNode is all the information we know about a single cached trie node +// in the memory. 
+type memoryNode struct { + hash common.Hash // Node hash, computed by hashing rlp value, empty for deleted nodes + size uint16 // Byte size of the useful cached data, 0 for deleted nodes + node node // Cached collapsed trie node, or raw rlp data, nil for deleted nodes +} + +// memoryNodeSize is the raw size of a memoryNode data structure without any +// node data included. It's an approximate size, but should be a lot better +// than not counting them. +// nolint:unused +var memoryNodeSize = int(reflect.TypeOf(memoryNode{}).Size()) + +// memorySize returns the total memory size used by this node. +// nolint:unused +func (n *memoryNode) memorySize(pathlen int) int { + return int(n.size) + memoryNodeSize + pathlen +} + +// rlp returns the raw rlp encoded blob of the cached trie node, either directly +// from the cache, or by regenerating it from the collapsed node. +// nolint:unused +func (n *memoryNode) rlp() []byte { + if node, ok := n.node.(rawNode); ok { + return node + } + return nodeToBytes(n.node) +} + +// obj returns the decoded and expanded trie node, either directly from the cache, +// or by regenerating it from the rlp encoded blob. +// nolint:unused +func (n *memoryNode) obj() node { + if node, ok := n.node.(rawNode); ok { + return mustDecodeNode(n.hash[:], node) + } + return expandNode(n.hash[:], n.node) +} + +// isDeleted returns the indicator if the node is marked as deleted. +func (n *memoryNode) isDeleted() bool { + return n.hash == (common.Hash{}) +} + +// nodeWithPrev wraps the memoryNode with the previous node value. +// nolint: unused +type nodeWithPrev struct { + *memoryNode + prev []byte // RLP-encoded previous value, nil means it's non-existent +} + +// unwrap returns the internal memoryNode object. +// nolint:unused +func (n *nodeWithPrev) unwrap() *memoryNode { + return n.memoryNode +} + +// memorySize returns the total memory size used by this node. It overloads +// the function in memoryNode by counting the size of previous value as well. 
+// nolint: unused +func (n *nodeWithPrev) memorySize(pathlen int) int { + return n.memoryNode.memorySize(pathlen) + len(n.prev) +} + +// NodeSet contains all dirty nodes collected during the commit operation. +// Each node is keyed by path. It's not thread-safe to use. +type NodeSet struct { + owner common.Hash // the identifier of the trie + nodes map[string]*memoryNode // the set of dirty nodes(inserted, updated, deleted) + leaves []*leaf // the list of dirty leaves + updates int // the count of updated and inserted nodes + deletes int // the count of deleted nodes + + // The list of accessed nodes, which records the original node value. + // The origin value is expected to be nil for newly inserted node + // and is expected to be non-nil for other types(updated, deleted). + accessList map[string][]byte +} + +// NewNodeSet initializes an empty node set to be used for tracking dirty nodes +// from a specific account or storage trie. The owner is zero for the account +// trie and the owning account address hash for storage tries. +func NewNodeSet(owner common.Hash, accessList map[string][]byte) *NodeSet { + return &NodeSet{ + owner: owner, + nodes: make(map[string]*memoryNode), + accessList: accessList, + } +} + +// forEachWithOrder iterates the dirty nodes with the order from bottom to top, +// right to left, nodes with the longest path will be iterated first. +func (set *NodeSet) forEachWithOrder(callback func(path string, n *memoryNode)) { + var paths sort.StringSlice + for path := range set.nodes { + paths = append(paths, path) + } + // Bottom-up, longest path first + sort.Sort(sort.Reverse(paths)) + for _, path := range paths { + callback(path, set.nodes[path]) + } +} + +// markUpdated marks the node as dirty(newly-inserted or updated). +func (set *NodeSet) markUpdated(path []byte, node *memoryNode) { + set.nodes[string(path)] = node + set.updates += 1 +} + +// markDeleted marks the node as deleted. 
+func (set *NodeSet) markDeleted(path []byte) { + set.nodes[string(path)] = &memoryNode{} + set.deletes += 1 +} + +// addLeaf collects the provided leaf node into set. +func (set *NodeSet) addLeaf(node *leaf) { + set.leaves = append(set.leaves, node) +} + +// Size returns the number of dirty nodes in set. +func (set *NodeSet) Size() (int, int) { + return set.updates, set.deletes +} + +// Hashes returns the hashes of all updated nodes. TODO(rjl493456442) how can +// we get rid of it? +func (set *NodeSet) Hashes() []common.Hash { + var ret []common.Hash + for _, node := range set.nodes { + ret = append(ret, node.hash) + } + return ret +} + +// Summary returns a string-representation of the NodeSet. +func (set *NodeSet) Summary() string { + var out = new(strings.Builder) + fmt.Fprintf(out, "nodeset owner: %v\n", set.owner) + if set.nodes != nil { + for path, n := range set.nodes { + // Deletion + if n.isDeleted() { + fmt.Fprintf(out, " [-]: %x prev: %x\n", path, set.accessList[path]) + continue + } + // Insertion + origin, ok := set.accessList[path] + if !ok { + fmt.Fprintf(out, " [+]: %x -> %v\n", path, n.hash) + continue + } + // Update + fmt.Fprintf(out, " [*]: %x -> %v prev: %x\n", path, n.hash, origin) + } + } + for _, n := range set.leaves { + fmt.Fprintf(out, "[leaf]: %v\n", n) + } + return out.String() +} + +// MergedNodeSet represents a merged dirty node set for a group of tries. +type MergedNodeSet struct { + sets map[common.Hash]*NodeSet +} + +// NewMergedNodeSet initializes an empty merged set. +func NewMergedNodeSet() *MergedNodeSet { + return &MergedNodeSet{sets: make(map[common.Hash]*NodeSet)} +} + +// NewWithNodeSet constructs a merged nodeset with the provided single set. +func NewWithNodeSet(set *NodeSet) *MergedNodeSet { + merged := NewMergedNodeSet() + merged.Merge(set) + return merged +} + +// Merge merges the provided dirty nodes of a trie into the set. 
The assumption +// is held that no duplicated set belonging to the same trie will be merged twice. +func (set *MergedNodeSet) Merge(other *NodeSet) error { + _, present := set.sets[other.owner] + if present { + return fmt.Errorf("duplicate trie for owner %#x", other.owner) + } + set.sets[other.owner] = other + return nil +} diff --git a/trie_by_cid/trie/preimages.go b/trie_by_cid/trie/preimages.go new file mode 100644 index 0000000..a6359ca --- /dev/null +++ b/trie_by_cid/trie/preimages.go @@ -0,0 +1,78 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package trie + +import ( + "sync" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/ethdb" +) + +// preimageStore is the store for caching preimages of node key. +type preimageStore struct { + lock sync.RWMutex + disk ethdb.KeyValueStore + preimages map[common.Hash][]byte // Preimages of nodes from the secure trie + preimagesSize common.StorageSize // Storage size of the preimages cache +} + +// newPreimageStore initializes the store for caching preimages. 
+func newPreimageStore(disk ethdb.KeyValueStore) *preimageStore { + return &preimageStore{ + disk: disk, + preimages: make(map[common.Hash][]byte), + } +} + +// insertPreimage writes a new trie node pre-image to the memory database if it's +// yet unknown. The method will NOT make a copy of the slice, only use if the +// preimage will NOT be changed later on. +func (store *preimageStore) insertPreimage(preimages map[common.Hash][]byte) { + store.lock.Lock() + defer store.lock.Unlock() + + for hash, preimage := range preimages { + if _, ok := store.preimages[hash]; ok { + continue + } + store.preimages[hash] = preimage + store.preimagesSize += common.StorageSize(common.HashLength + len(preimage)) + } +} + +// preimage retrieves a cached trie node pre-image from memory. If it cannot be +// found cached, the method queries the persistent database for the content. +func (store *preimageStore) preimage(hash common.Hash) []byte { + store.lock.RLock() + preimage := store.preimages[hash] + store.lock.RUnlock() + + if preimage != nil { + return preimage + } + return rawdb.ReadPreimage(store.disk, hash) +} + +// size returns the current storage size of accumulated preimages. +func (store *preimageStore) size() common.StorageSize { + store.lock.RLock() + defer store.lock.RUnlock() + + return store.preimagesSize +} diff --git a/trie_by_cid/trie/proof.go b/trie_by_cid/trie/proof.go index e48eda5..7315c0d 100644 --- a/trie_by_cid/trie/proof.go +++ b/trie_by_cid/trie/proof.go @@ -20,9 +20,10 @@ import ( "bytes" "fmt" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/ethdb" - "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/trie" + log "github.com/sirupsen/logrus" ) var VerifyProof = trie.VerifyProof @@ -61,10 +62,15 @@ func (t *Trie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) e key = key[1:] nodes = append(nodes, n) case hashNode: + // Retrieve the specified node from the underlying node reader. 
+ // trie.resolveAndTrack is not used since in that function the + // loaded blob will be tracked, while it's not required here since + // all loaded nodes won't be linked to trie at all and track nodes + // may lead to out-of-memory issue. var err error - tn, err = t.resolveHash(n, prefix) + tn, err = t.reader.node(prefix, common.BytesToHash(n)) if err != nil { - log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) + log.Error("Unhandled trie error in Trie.Prove", "err", err) return err } default: diff --git a/trie_by_cid/trie/proof_test.go b/trie_by_cid/trie/proof_test.go index 3fd508a..6b23bcd 100644 --- a/trie_by_cid/trie/proof_test.go +++ b/trie_by_cid/trie/proof_test.go @@ -14,29 +14,39 @@ // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . -package trie_test +package trie import ( "bytes" + crand "crypto/rand" + "encoding/binary" + "fmt" mrand "math/rand" "sort" "testing" - "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb/memorydb" - geth_trie "github.com/ethereum/go-ethereum/trie" - - . "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie" ) -// tweakable trie size -var scaleFactor = 512 +// Prng is a pseudo random number generator seeded by strong randomness. +// The randomness is printed on startup in order to make failures reproducible. 
+var prng = initRnd() -func init() { - mrand.Seed(time.Now().UnixNano()) +func initRnd() *mrand.Rand { + var seed [8]byte + crand.Read(seed[:]) + rnd := mrand.New(mrand.NewSource(int64(binary.LittleEndian.Uint64(seed[:])))) + fmt.Printf("Seed: %x\n", seed) + return rnd +} + +func randBytes(n int) []byte { + r := make([]byte, n) + prng.Read(r) + return r } // makeProvers creates Merkle trie provers based on different implementations to @@ -64,7 +74,7 @@ func makeProvers(trie *Trie) []func(key []byte) *memorydb.Database { } func TestProof(t *testing.T) { - trie, vals := randomTrie(t, scaleFactor) + trie, vals := randomTrie(500) root := trie.Hash() for i, prover := range makeProvers(trie) { for _, kv := range vals { @@ -72,11 +82,11 @@ func TestProof(t *testing.T) { if proof == nil { t.Fatalf("prover %d: missing key %x while constructing proof", i, kv.k) } - val, err := geth_trie.VerifyProof(root, kv.k, proof) + val, err := VerifyProof(root, kv.k, proof) if err != nil { t.Fatalf("prover %d: failed to verify proof for key %x: %v\nraw proof: %x", i, kv.k, err, proof) } - if kv.v != unpackValue(val) { + if !bytes.Equal(val, kv.v) { t.Fatalf("prover %d: verified value mismatch for key %x: have %x, want %x", i, kv.k, val, kv.v) } } @@ -84,13 +94,8 @@ func TestProof(t *testing.T) { } func TestOneElementProof(t *testing.T) { - edb := rawdb.NewMemoryDatabase() - db := geth_trie.NewDatabase(edb) - orig := geth_trie.NewEmpty(db) - orig.Update([]byte("k"), packValue(42)) - root := commitTrie(t, db, orig) - trie := indexTrie(t, edb, root) - + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + updateString(trie, "k", "v") for i, prover := range makeProvers(trie) { proof := prover([]byte("k")) if proof == nil { @@ -99,18 +104,18 @@ func TestOneElementProof(t *testing.T) { if proof.Len() != 1 { t.Errorf("prover %d: proof should have one element", i) } - val, err := geth_trie.VerifyProof(trie.Hash(), []byte("k"), proof) + val, err := VerifyProof(trie.Hash(), []byte("k"), 
proof) if err != nil { t.Fatalf("prover %d: failed to verify proof: %v\nraw proof: %x", i, err, proof) } - if 42 != unpackValue(val) { + if !bytes.Equal(val, []byte("v")) { t.Fatalf("prover %d: verified value mismatch: have %x, want 'k'", i, val) } } } func TestBadProof(t *testing.T) { - trie, vals := randomTrie(t, 2*scaleFactor) + trie, vals := randomTrie(800) root := trie.Hash() for i, prover := range makeProvers(trie) { for _, kv := range vals { @@ -140,12 +145,8 @@ func TestBadProof(t *testing.T) { // Tests that missing keys can also be proven. The test explicitly uses a single // entry trie and checks for missing keys both before and after the single entry. func TestMissingKeyProof(t *testing.T) { - edb := rawdb.NewMemoryDatabase() - db := geth_trie.NewDatabase(edb) - orig := geth_trie.NewEmpty(db) - orig.Update([]byte("k"), packValue(42)) - root := commitTrie(t, db, orig) - trie := indexTrie(t, edb, root) + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + updateString(trie, "k", "v") for i, key := range []string{"a", "j", "l", "z"} { proof := memorydb.New() @@ -164,15 +165,7 @@ func TestMissingKeyProof(t *testing.T) { } } -type entry struct { - k, v []byte -} - -func packEntry(kv *kv) *entry { - return &entry{kv.k, packValue(kv.v)} -} - -type entrySlice []*entry +type entrySlice []*kv func (p entrySlice) Len() int { return len(p) } func (p entrySlice) Less(i, j int) bool { return bytes.Compare(p[i].k, p[j].k) < 0 } @@ -181,10 +174,10 @@ func (p entrySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } // TestRangeProof tests normal range proof with both edge proofs // as the existent proof. The test cases are generated randomly. 
func TestRangeProof(t *testing.T) { - trie, vals := randomTrie(t, 8*scaleFactor) + trie, vals := randomTrie(4096) var entries entrySlice for _, kv := range vals { - entries = append(entries, packEntry(kv)) + entries = append(entries, kv) } sort.Sort(entries) for i := 0; i < 500; i++ { @@ -214,10 +207,10 @@ func TestRangeProof(t *testing.T) { // TestRangeProof tests normal range proof with two non-existent proofs. // The test cases are generated randomly. func TestRangeProofWithNonExistentProof(t *testing.T) { - trie, vals := randomTrie(t, 8*scaleFactor) + trie, vals := randomTrie(4096) var entries entrySlice for _, kv := range vals { - entries = append(entries, packEntry(kv)) + entries = append(entries, kv) } sort.Sort(entries) for i := 0; i < 500; i++ { @@ -286,10 +279,10 @@ func TestRangeProofWithNonExistentProof(t *testing.T) { // - There exists a gap between the first element and the left edge proof // - There exists a gap between the last element and the right edge proof func TestRangeProofWithInvalidNonExistentProof(t *testing.T) { - trie, vals := randomTrie(t, 8*scaleFactor) + trie, vals := randomTrie(4096) var entries entrySlice for _, kv := range vals { - entries = append(entries, packEntry(kv)) + entries = append(entries, kv) } sort.Sort(entries) @@ -343,10 +336,10 @@ func TestRangeProofWithInvalidNonExistentProof(t *testing.T) { // element. The first edge proof can be existent one or // non-existent one. func TestOneElementRangeProof(t *testing.T) { - trie, vals := randomTrie(t, 8*scaleFactor) + trie, vals := randomTrie(4096) var entries entrySlice for _, kv := range vals { - entries = append(entries, packEntry(kv)) + entries = append(entries, kv) } sort.Sort(entries) @@ -408,38 +401,32 @@ func TestOneElementRangeProof(t *testing.T) { } // Test the mini trie with only a single element. 
- t.Run("single element", func(t *testing.T) { - edb := rawdb.NewMemoryDatabase() - db := geth_trie.NewDatabase(edb) - orig := geth_trie.NewEmpty(db) - entry := &entry{randBytes(32), packValue(mrand.Int63())} - orig.Update(entry.k, entry.v) - root := commitTrie(t, db, orig) - tinyTrie := indexTrie(t, edb, root) + tinyTrie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + entry := &kv{randBytes(32), randBytes(20), false} + tinyTrie.Update(entry.k, entry.v) - first = common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes() - last = entry.k - proof = memorydb.New() - if err := tinyTrie.Prove(first, 0, proof); err != nil { - t.Fatalf("Failed to prove the first node %v", err) - } - if err := tinyTrie.Prove(last, 0, proof); err != nil { - t.Fatalf("Failed to prove the last node %v", err) - } - _, err = VerifyRangeProof(tinyTrie.Hash(), first, last, [][]byte{entry.k}, [][]byte{entry.v}, proof) - if err != nil { - t.Fatalf("Expected no error, got %v", err) - } - }) + first = common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes() + last = entry.k + proof = memorydb.New() + if err := tinyTrie.Prove(first, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + if err := tinyTrie.Prove(last, 0, proof); err != nil { + t.Fatalf("Failed to prove the last node %v", err) + } + _, err = VerifyRangeProof(tinyTrie.Hash(), first, last, [][]byte{entry.k}, [][]byte{entry.v}, proof) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } } // TestAllElementsProof tests the range proof with all elements. // The edge proofs can be nil. 
func TestAllElementsProof(t *testing.T) { - trie, vals := randomTrie(t, 8*scaleFactor) + trie, vals := randomTrie(4096) var entries entrySlice for _, kv := range vals { - entries = append(entries, packEntry(kv)) + entries = append(entries, kv) } sort.Sort(entries) @@ -483,86 +470,73 @@ func TestAllElementsProof(t *testing.T) { } } -func positionCases(size int) []int { - var cases []int - for _, pos := range []int{0, 1, 50, 100, 1000, 2000, size - 1} { - if pos >= size { - break - } - cases = append(cases, pos) - } - return cases -} - // TestSingleSideRangeProof tests the range starts from zero. func TestSingleSideRangeProof(t *testing.T) { - edb := rawdb.NewMemoryDatabase() - db := geth_trie.NewDatabase(edb) - orig := geth_trie.NewEmpty(db) - var entries entrySlice - for i := 0; i < 8*scaleFactor; i++ { - value := &entry{randBytes(32), packValue(mrand.Int63())} - orig.Update(value.k, value.v) - entries = append(entries, value) - } - root := commitTrie(t, db, orig) - trie := indexTrie(t, edb, root) - sort.Sort(entries) + for i := 0; i < 64; i++ { + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + var entries entrySlice + for i := 0; i < 4096; i++ { + value := &kv{randBytes(32), randBytes(20), false} + trie.Update(value.k, value.v) + entries = append(entries, value) + } + sort.Sort(entries) - for _, pos := range positionCases(len(entries)) { - proof := memorydb.New() - if err := trie.Prove(common.Hash{}.Bytes(), 0, proof); err != nil { - t.Fatalf("Failed to prove the first node %v", err) - } - if err := trie.Prove(entries[pos].k, 0, proof); err != nil { - t.Fatalf("Failed to prove the first node %v", err) - } - k := make([][]byte, 0) - v := make([][]byte, 0) - for i := 0; i <= pos; i++ { - k = append(k, entries[i].k) - v = append(v, entries[i].v) - } - _, err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k[len(k)-1], k, v, proof) - if err != nil { - t.Fatalf("Expected no error, got %v", err) + var cases = []int{0, 1, 50, 100, 1000, 2000, 
len(entries) - 1} + for _, pos := range cases { + proof := memorydb.New() + if err := trie.Prove(common.Hash{}.Bytes(), 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + if err := trie.Prove(entries[pos].k, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + k := make([][]byte, 0) + v := make([][]byte, 0) + for i := 0; i <= pos; i++ { + k = append(k, entries[i].k) + v = append(v, entries[i].v) + } + _, err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k[len(k)-1], k, v, proof) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } } } } // TestReverseSingleSideRangeProof tests the range ends with 0xffff...fff. func TestReverseSingleSideRangeProof(t *testing.T) { - edb := rawdb.NewMemoryDatabase() - db := geth_trie.NewDatabase(edb) - orig := geth_trie.NewEmpty(db) - var entries entrySlice - for i := 0; i < 8*scaleFactor; i++ { - value := &entry{randBytes(32), packValue(mrand.Int63())} - orig.Update(value.k, value.v) - entries = append(entries, value) - } - root := commitTrie(t, db, orig) - trie := indexTrie(t, edb, root) - sort.Sort(entries) + for i := 0; i < 64; i++ { + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + var entries entrySlice + for i := 0; i < 4096; i++ { + value := &kv{randBytes(32), randBytes(20), false} + trie.Update(value.k, value.v) + entries = append(entries, value) + } + sort.Sort(entries) - for _, pos := range positionCases(len(entries)) { - proof := memorydb.New() - if err := trie.Prove(entries[pos].k, 0, proof); err != nil { - t.Fatalf("Failed to prove the first node %v", err) - } - last := common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") - if err := trie.Prove(last.Bytes(), 0, proof); err != nil { - t.Fatalf("Failed to prove the last node %v", err) - } - k := make([][]byte, 0) - v := make([][]byte, 0) - for i := pos; i < len(entries); i++ { - k = append(k, entries[i].k) - v = append(v, entries[i].v) - } - _, 
err := VerifyRangeProof(trie.Hash(), k[0], last.Bytes(), k, v, proof) - if err != nil { - t.Fatalf("Expected no error, got %v", err) + var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1} + for _, pos := range cases { + proof := memorydb.New() + if err := trie.Prove(entries[pos].k, 0, proof); err != nil { + t.Fatalf("Failed to prove the first node %v", err) + } + last := common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") + if err := trie.Prove(last.Bytes(), 0, proof); err != nil { + t.Fatalf("Failed to prove the last node %v", err) + } + k := make([][]byte, 0) + v := make([][]byte, 0) + for i := pos; i < len(entries); i++ { + k = append(k, entries[i].k) + v = append(v, entries[i].v) + } + _, err := VerifyRangeProof(trie.Hash(), k[0], last.Bytes(), k, v, proof) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } } } } @@ -570,10 +544,10 @@ func TestReverseSingleSideRangeProof(t *testing.T) { // TestBadRangeProof tests a few cases which the proof is wrong. // The prover is expected to detect the error. func TestBadRangeProof(t *testing.T) { - trie, vals := randomTrie(t, 8*scaleFactor) + trie, vals := randomTrie(4096) var entries entrySlice for _, kv := range vals { - entries = append(entries, packEntry(kv)) + entries = append(entries, kv) } sort.Sort(entries) @@ -641,17 +615,13 @@ func TestBadRangeProof(t *testing.T) { // TestGappedRangeProof focuses on the small trie with embedded nodes. // If the gapped node is embedded in the trie, it should be detected too. 
func TestGappedRangeProof(t *testing.T) { - edb := rawdb.NewMemoryDatabase() - db := geth_trie.NewDatabase(edb) - orig := geth_trie.NewEmpty(db) - var entries entrySlice + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + var entries []*kv // Sorted entries for i := byte(0); i < 10; i++ { - value := &entry{common.LeftPadBytes([]byte{i}, 32), packValue(int64(i))} - orig.Update(value.k, value.v) + value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} + trie.Update(value.k, value.v) entries = append(entries, value) } - root := commitTrie(t, db, orig) - trie := indexTrie(t, edb, root) first, last := 2, 8 proof := memorydb.New() if err := trie.Prove(entries[first].k, 0, proof); err != nil { @@ -677,10 +647,10 @@ func TestGappedRangeProof(t *testing.T) { // TestSameSideProofs tests the element is not in the range covered by proofs func TestSameSideProofs(t *testing.T) { - trie, vals := randomTrie(t, 8*scaleFactor) + trie, vals := randomTrie(4096) var entries entrySlice for _, kv := range vals { - entries = append(entries, packEntry(kv)) + entries = append(entries, kv) } sort.Sort(entries) @@ -719,17 +689,13 @@ func TestSameSideProofs(t *testing.T) { } func TestHasRightElement(t *testing.T) { - edb := rawdb.NewMemoryDatabase() - db := geth_trie.NewDatabase(edb) - orig := geth_trie.NewEmpty(db) + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) var entries entrySlice - for i := 0; i < 8*scaleFactor; i++ { - value := &entry{randBytes(32), packValue(int64(i))} - orig.Update(value.k, value.v) + for i := 0; i < 4096; i++ { + value := &kv{randBytes(32), randBytes(20), false} + trie.Update(value.k, value.v) entries = append(entries, value) } - root := commitTrie(t, db, orig) - trie := indexTrie(t, edb, root) sort.Sort(entries) var cases = []struct { @@ -797,10 +763,10 @@ func TestHasRightElement(t *testing.T) { // TestEmptyRangeProof tests the range proof with "no" element. // The first edge proof must be a non-existent proof. 
func TestEmptyRangeProof(t *testing.T) { - trie, vals := randomTrie(t, 8*scaleFactor) + trie, vals := randomTrie(4096) var entries entrySlice for _, kv := range vals { - entries = append(entries, packEntry(kv)) + entries = append(entries, kv) } sort.Sort(entries) @@ -827,14 +793,49 @@ func TestEmptyRangeProof(t *testing.T) { } } +// TestBloatedProof tests a malicious proof, where the proof is more or less the +// whole trie. Previously we didn't accept such packets, but the new APIs do, so +// lets leave this test as a bit weird, but present. +func TestBloatedProof(t *testing.T) { + // Use a small trie + trie, kvs := nonRandomTrie(100) + var entries entrySlice + for _, kv := range kvs { + entries = append(entries, kv) + } + sort.Sort(entries) + var keys [][]byte + var vals [][]byte + + proof := memorydb.New() + // In the 'malicious' case, we add proofs for every single item + // (but only one key/value pair used as leaf) + for i, entry := range entries { + trie.Prove(entry.k, 0, proof) + if i == 50 { + keys = append(keys, entry.k) + vals = append(vals, entry.v) + } + } + // For reference, we use the same function, but _only_ prove the first + // and last element + want := memorydb.New() + trie.Prove(keys[0], 0, want) + trie.Prove(keys[len(keys)-1], 0, want) + + if _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof); err != nil { + t.Fatalf("expected bloated proof to succeed, got %v", err) + } +} + // TestEmptyValueRangeProof tests normal range proof with both edge proofs // as the existent proof, but with an extra empty value included, which is a // noop technically, but practically should be rejected. 
func TestEmptyValueRangeProof(t *testing.T) { - trie, values := randomTrie(t, scaleFactor) + trie, values := randomTrie(512) var entries entrySlice for _, kv := range values { - entries = append(entries, packEntry(kv)) + entries = append(entries, kv) } sort.Sort(entries) @@ -847,8 +848,8 @@ func TestEmptyValueRangeProof(t *testing.T) { break } } - noop := &entry{key, []byte{}} - entries = append(append(append(entrySlice{}, entries[:mid]...), noop), entries[mid:]...) + noop := &kv{key, []byte{}, false} + entries = append(append(append([]*kv{}, entries[:mid]...), noop), entries[mid:]...) start, end := 1, len(entries)-1 @@ -875,10 +876,10 @@ func TestEmptyValueRangeProof(t *testing.T) { // but with an extra empty value included, which is a noop technically, but // practically should be rejected. func TestAllElementsEmptyValueRangeProof(t *testing.T) { - trie, values := randomTrie(t, scaleFactor) + trie, values := randomTrie(512) var entries entrySlice for _, kv := range values { - entries = append(entries, packEntry(kv)) + entries = append(entries, kv) } sort.Sort(entries) @@ -891,8 +892,8 @@ func TestAllElementsEmptyValueRangeProof(t *testing.T) { break } } - noop := &entry{key, nil} - entries = append(append(append(entrySlice{}, entries[:mid]...), noop), entries[mid:]...) + noop := &kv{key, []byte{}, false} + entries = append(append(append([]*kv{}, entries[:mid]...), noop), entries[mid:]...) 
var keys [][]byte var vals [][]byte @@ -938,7 +939,7 @@ func decreaseKey(key []byte) []byte { } func BenchmarkProve(b *testing.B) { - trie, vals := randomTrie(b, 100) + trie, vals := randomTrie(100) var keys []string for k := range vals { keys = append(keys, k) @@ -955,7 +956,7 @@ func BenchmarkProve(b *testing.B) { } func BenchmarkVerifyProof(b *testing.B) { - trie, vals := randomTrie(b, 100) + trie, vals := randomTrie(100) root := trie.Hash() var keys []string var proofs []*memorydb.Database @@ -981,10 +982,10 @@ func BenchmarkVerifyRangeProof1000(b *testing.B) { benchmarkVerifyRangeProof(b, func BenchmarkVerifyRangeProof5000(b *testing.B) { benchmarkVerifyRangeProof(b, 5000) } func benchmarkVerifyRangeProof(b *testing.B, size int) { - trie, vals := randomTrie(b, 8192) + trie, vals := randomTrie(8192) var entries entrySlice for _, kv := range vals { - entries = append(entries, packEntry(kv)) + entries = append(entries, kv) } sort.Sort(entries) @@ -1018,10 +1019,10 @@ func BenchmarkVerifyRangeNoProof500(b *testing.B) { benchmarkVerifyRangeNoProof func BenchmarkVerifyRangeNoProof1000(b *testing.B) { benchmarkVerifyRangeNoProof(b, 1000) } func benchmarkVerifyRangeNoProof(b *testing.B, size int) { - trie, vals := randomTrie(b, size) + trie, vals := randomTrie(size) var entries entrySlice for _, kv := range vals { - entries = append(entries, packEntry(kv)) + entries = append(entries, kv) } sort.Sort(entries) @@ -1040,24 +1041,56 @@ func benchmarkVerifyRangeNoProof(b *testing.B, size int) { } } +func randomTrie(n int) (*Trie, map[string]*kv) { + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + vals := make(map[string]*kv) + for i := byte(0); i < 100; i++ { + value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} + value2 := &kv{common.LeftPadBytes([]byte{i + 10}, 32), []byte{i}, false} + trie.Update(value.k, value.v) + trie.Update(value2.k, value2.v) + vals[string(value.k)] = value + vals[string(value2.k)] = value2 + } + for i := 0; i < n; i++ { + 
value := &kv{randBytes(32), randBytes(20), false} + trie.Update(value.k, value.v) + vals[string(value.k)] = value + } + return trie, vals +} + +func nonRandomTrie(n int) (*Trie, map[string]*kv) { + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + vals := make(map[string]*kv) + max := uint64(0xffffffffffffffff) + for i := uint64(0); i < uint64(n); i++ { + value := make([]byte, 32) + key := make([]byte, 32) + binary.LittleEndian.PutUint64(key, i) + binary.LittleEndian.PutUint64(value, i-max) + //value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} + elem := &kv{key, value, false} + trie.Update(elem.k, elem.v) + vals[string(elem.k)] = elem + } + return trie, vals +} + func TestRangeProofKeysWithSharedPrefix(t *testing.T) { keys := [][]byte{ common.Hex2Bytes("aa10000000000000000000000000000000000000000000000000000000000000"), common.Hex2Bytes("aa20000000000000000000000000000000000000000000000000000000000000"), } vals := [][]byte{ - packValue(2), - packValue(3), + common.Hex2Bytes("02"), + common.Hex2Bytes("03"), } - edb := rawdb.NewMemoryDatabase() - db := geth_trie.NewDatabase(edb) - orig := geth_trie.NewEmpty(db) + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) for i, key := range keys { - orig.Update(key, vals[i]) + trie.Update(key, vals[i]) } - root := commitTrie(t, db, orig) - trie := indexTrie(t, edb, root) - + root := trie.Hash() proof := memorydb.New() start := common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000") end := common.Hex2Bytes("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") diff --git a/trie_by_cid/trie/secure_trie.go b/trie_by_cid/trie/secure_trie.go index 6e4b4ca..25f1fd7 100644 --- a/trie_by_cid/trie/secure_trie.go +++ b/trie_by_cid/trie/secure_trie.go @@ -17,82 +17,88 @@ package trie import ( - "fmt" - "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/rlp" - 
"github.com/ethereum/go-ethereum/statediff/indexer/ipld" + log "github.com/sirupsen/logrus" ) -// StateTrie wraps a trie with key hashing. In a secure trie, all +// StateTrie wraps a trie with key hashing. In a stateTrie trie, all // access operations hash the key using keccak256. This prevents // calling code from creating long chains of nodes that // increase the access time. // // Contrary to a regular trie, a StateTrie can only be created with -// New and must have an attached database. +// New and must have an attached database. The database also stores +// the preimage of each key if preimage recording is enabled. // // StateTrie is not safe for concurrent use. type StateTrie struct { - trie Trie - hashKeyBuf [common.HashLength]byte + trie Trie + preimages *preimageStore + hashKeyBuf [common.HashLength]byte + secKeyCache map[string][]byte + secKeyCacheOwner *StateTrie // Pointer to self, replace the key cache on mismatch } -// NewStateTrie creates a trie with an existing root node from a backing database -// and optional intermediate in-memory node pool. +// NewStateTrie creates a trie with an existing root node from a backing database. // // If root is the zero hash or the sha3 hash of an empty string, the // trie is initially empty. Otherwise, New will panic if db is nil // and returns MissingNodeError if the root node cannot be found. -// -// Accessing the trie loads nodes from the database or node pool on demand. -// Loaded nodes are kept around until their 'cache generation' expires. -// A new cache generation is created by each call to Commit. -// cachelimit sets the number of past cache generations to keep. 
-// -// Retrieves IPLD blocks by CID encoded as "eth-state-trie" -func NewStateTrie(owner common.Hash, root common.Hash, db *Database) (*StateTrie, error) { - return newStateTrie(owner, root, db, ipld.MEthStateTrie) -} - -// NewStorageTrie is identical to NewStateTrie, but retrieves IPLD blocks encoded -// as "eth-storage-trie" -func NewStorageTrie(owner common.Hash, root common.Hash, db *Database) (*StateTrie, error) { - return newStateTrie(owner, root, db, ipld.MEthStorageTrie) -} - -func newStateTrie(owner common.Hash, root common.Hash, db *Database, codec uint64) (*StateTrie, error) { +func NewStateTrie(id *ID, db *Database, codec uint64) (*StateTrie, error) { if db == nil { - panic("NewStateTrie called without a database") + panic("trie.NewStateTrie called without a database") } - trie, err := New(owner, root, db, codec) + trie, err := New(id, db, codec) if err != nil { return nil, err } - return &StateTrie{trie: *trie}, nil + return &StateTrie{trie: *trie, preimages: db.preimages}, nil +} + +// Get returns the value for key stored in the trie. +// The value bytes must not be modified by the caller. +func (t *StateTrie) Get(key []byte) []byte { + res, err := t.TryGet(key) + if err != nil { + log.Error("Unhandled trie error in StateTrie.Get", "err", err) + } + return res } // TryGet returns the value for key stored in the trie. // The value bytes must not be modified by the caller. -// If a node was not found in the database, a MissingNodeError is returned. +// If the specified node is not in the trie, nil will be returned. +// If a trie node is not found in the database, a MissingNodeError is returned. 
func (t *StateTrie) TryGet(key []byte) ([]byte, error) { return t.trie.TryGet(t.hashKey(key)) } -func (t *StateTrie) TryGetAccount(key []byte) (*types.StateAccount, error) { - var ret types.StateAccount - res, err := t.TryGet(key) - if err != nil { - // log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) - panic(fmt.Sprintf("Unhandled trie error: %v", err)) - return &ret, err +// TryGetAccount attempts to retrieve an account with provided account address. +// If the specified account is not in the trie, nil will be returned. +// If a trie node is not found in the database, a MissingNodeError is returned. +func (t *StateTrie) TryGetAccount(address common.Address) (*types.StateAccount, error) { + res, err := t.trie.TryGet(t.hashKey(address.Bytes())) + if res == nil || err != nil { + return nil, err } - if res == nil { - return nil, nil + ret := new(types.StateAccount) + err = rlp.DecodeBytes(res, ret) + return ret, err +} + +// TryGetAccountByHash does the same thing as TryGetAccount, however +// it expects an account hash that is the hash of address. This constitutes an +// abstraction leak, since the client code needs to know the key format. +func (t *StateTrie) TryGetAccountByHash(addrHash common.Hash) (*types.StateAccount, error) { + res, err := t.trie.TryGet(addrHash.Bytes()) + if res == nil || err != nil { + return nil, err } - err = rlp.DecodeBytes(res, &ret) - return &ret, err + ret := new(types.StateAccount) + err = rlp.DecodeBytes(res, ret) + return ret, err } // TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not @@ -103,12 +109,124 @@ func (t *StateTrie) TryGetNode(path []byte) ([]byte, int, error) { return t.trie.TryGetNode(path) } +// Update associates key with value in the trie. Subsequent calls to +// Get will return value. If value has length zero, any existing value +// is deleted from the trie and calls to Get will return nil. 
+// +// The value bytes must not be modified by the caller while they are +// stored in the trie. +func (t *StateTrie) Update(key, value []byte) { + if err := t.TryUpdate(key, value); err != nil { + log.Error("Unhandled trie error in StateTrie.Update", "err", err) + } +} + +// TryUpdate associates key with value in the trie. Subsequent calls to +// Get will return value. If value has length zero, any existing value +// is deleted from the trie and calls to Get will return nil. +// +// The value bytes must not be modified by the caller while they are +// stored in the trie. +// +// If a node is not found in the database, a MissingNodeError is returned. +func (t *StateTrie) TryUpdate(key, value []byte) error { + hk := t.hashKey(key) + err := t.trie.TryUpdate(hk, value) + if err != nil { + return err + } + t.getSecKeyCache()[string(hk)] = common.CopyBytes(key) + return nil +} + +// TryUpdateAccount account will abstract the write of an account to the +// secure trie. +func (t *StateTrie) TryUpdateAccount(address common.Address, acc *types.StateAccount) error { + hk := t.hashKey(address.Bytes()) + data, err := rlp.EncodeToBytes(acc) + if err != nil { + return err + } + if err := t.trie.TryUpdate(hk, data); err != nil { + return err + } + t.getSecKeyCache()[string(hk)] = address.Bytes() + return nil +} + +// Delete removes any existing value for key from the trie. +func (t *StateTrie) Delete(key []byte) { + if err := t.TryDelete(key); err != nil { + log.Error("Unhandled trie error in StateTrie.Delete", "err", err) + } +} + +// TryDelete removes any existing value for key from the trie. +// If the specified trie node is not in the trie, nothing will be changed. +// If a node is not found in the database, a MissingNodeError is returned. +func (t *StateTrie) TryDelete(key []byte) error { + hk := t.hashKey(key) + delete(t.getSecKeyCache(), string(hk)) + return t.trie.TryDelete(hk) +} + +// TryDeleteAccount abstracts an account deletion from the trie. 
+func (t *StateTrie) TryDeleteAccount(address common.Address) error { + hk := t.hashKey(address.Bytes()) + delete(t.getSecKeyCache(), string(hk)) + return t.trie.TryDelete(hk) +} + +// GetKey returns the sha3 preimage of a hashed key that was +// previously used to store a value. +func (t *StateTrie) GetKey(shaKey []byte) []byte { + if key, ok := t.getSecKeyCache()[string(shaKey)]; ok { + return key + } + if t.preimages == nil { + return nil + } + return t.preimages.preimage(common.BytesToHash(shaKey)) +} + +// Commit collects all dirty nodes in the trie and replaces them with the +// corresponding node hash. All collected nodes (including dirty leaves if +// collectLeaf is true) will be encapsulated into a nodeset for return. +// The returned nodeset can be nil if the trie is clean (nothing to commit). +// All cached preimages will be also flushed if preimages recording is enabled. +// Once the trie is committed, it's not usable anymore. A new trie must +// be created with new root and updated trie database for following usage +func (t *StateTrie) Commit(collectLeaf bool) (common.Hash, *NodeSet) { + // Write all the pre-images to the actual disk database + if len(t.getSecKeyCache()) > 0 { + if t.preimages != nil { + preimages := make(map[common.Hash][]byte) + for hk, key := range t.secKeyCache { + preimages[common.BytesToHash([]byte(hk))] = key + } + t.preimages.insertPreimage(preimages) + } + t.secKeyCache = make(map[string][]byte) + } + // Commit the trie and return its modified nodeset. + return t.trie.Commit(collectLeaf) +} + // Hash returns the root hash of StateTrie. It does not write to the // database and can be used even if the trie doesn't have one. func (t *StateTrie) Hash() common.Hash { return t.trie.Hash() } +// Copy returns a copy of StateTrie. 
+func (t *StateTrie) Copy() *StateTrie { + return &StateTrie{ + trie: *t.trie.Copy(), + preimages: t.preimages, + secKeyCache: t.secKeyCache, + } +} + // NodeIterator returns an iterator that returns nodes of the underlying trie. Iteration // starts at the key after the given start key. func (t *StateTrie) NodeIterator(start []byte) NodeIterator { @@ -126,3 +244,14 @@ func (t *StateTrie) hashKey(key []byte) []byte { returnHasherToPool(h) return t.hashKeyBuf[:] } + +// getSecKeyCache returns the current secure key cache, creating a new one if +// ownership changed (i.e. the current secure trie is a copy of another owning +// the actual cache). +func (t *StateTrie) getSecKeyCache() map[string][]byte { + if t != t.secKeyCacheOwner { + t.secKeyCacheOwner = t + t.secKeyCache = make(map[string][]byte) + } + return t.secKeyCache +} diff --git a/trie_by_cid/trie/secure_trie_test.go b/trie_by_cid/trie/secure_trie_test.go new file mode 100644 index 0000000..41edbc3 --- /dev/null +++ b/trie_by_cid/trie/secure_trie_test.go @@ -0,0 +1,148 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . 
+ +package trie + +import ( + "bytes" + "fmt" + "runtime" + "sync" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/crypto" +) + +// makeTestStateTrie creates a large enough secure trie for testing. +func makeTestStateTrie() (*Database, *StateTrie, map[string][]byte) { + // Create an empty trie + triedb := NewDatabase(rawdb.NewMemoryDatabase()) + trie, _ := NewStateTrie(TrieID(common.Hash{}), triedb, StateTrieCodec) + + // Fill it with some arbitrary data + content := make(map[string][]byte) + for i := byte(0); i < 255; i++ { + // Map the same data under multiple keys + key, val := common.LeftPadBytes([]byte{1, i}, 32), []byte{i} + content[string(key)] = val + trie.Update(key, val) + + key, val = common.LeftPadBytes([]byte{2, i}, 32), []byte{i} + content[string(key)] = val + trie.Update(key, val) + + // Add some other data to inflate the trie + for j := byte(3); j < 13; j++ { + key, val = common.LeftPadBytes([]byte{j, i}, 32), []byte{j, i} + content[string(key)] = val + trie.Update(key, val) + } + } + root, nodes := trie.Commit(false) + if err := triedb.Update(NewWithNodeSet(nodes)); err != nil { + panic(fmt.Errorf("failed to commit db %v", err)) + } + // Re-create the trie based on the new state + trie, _ = NewStateTrie(TrieID(root), triedb, StateTrieCodec) + return triedb, trie, content +} + +func TestSecureDelete(t *testing.T) { + trie, err := NewStateTrie(TrieID(common.Hash{}), NewDatabase(rawdb.NewMemoryDatabase()), StateTrieCodec) + if err != nil { + t.Fatal(err) + } + vals := []struct{ k, v string }{ + {"do", "verb"}, + {"ether", "wookiedoo"}, + {"horse", "stallion"}, + {"shaman", "horse"}, + {"doge", "coin"}, + {"ether", ""}, + {"dog", "puppy"}, + {"shaman", ""}, + } + for _, val := range vals { + if val.v != "" { + trie.Update([]byte(val.k), []byte(val.v)) + } else { + trie.Delete([]byte(val.k)) + } + } + hash := trie.Hash() + exp := 
common.HexToHash("29b235a58c3c25ab83010c327d5932bcf05324b7d6b1185e650798034783ca9d") + if hash != exp { + t.Errorf("expected %x got %x", exp, hash) + } +} + +func TestSecureGetKey(t *testing.T) { + trie, err := NewStateTrie(TrieID(common.Hash{}), NewDatabase(rawdb.NewMemoryDatabase()), StateTrieCodec) + if err != nil { + t.Fatal(err) + } + trie.Update([]byte("foo"), []byte("bar")) + + key := []byte("foo") + value := []byte("bar") + seckey := crypto.Keccak256(key) + + if !bytes.Equal(trie.Get(key), value) { + t.Errorf("Get did not return bar") + } + if k := trie.GetKey(seckey); !bytes.Equal(k, key) { + t.Errorf("GetKey returned %q, want %q", k, key) + } +} + +func TestStateTrieConcurrency(t *testing.T) { + // Create an initial trie and copy if for concurrent access + _, trie, _ := makeTestStateTrie() + + threads := runtime.NumCPU() + tries := make([]*StateTrie, threads) + for i := 0; i < threads; i++ { + tries[i] = trie.Copy() + } + // Start a batch of goroutines interacting with the trie + pend := new(sync.WaitGroup) + pend.Add(threads) + for i := 0; i < threads; i++ { + go func(index int) { + defer pend.Done() + + for j := byte(0); j < 255; j++ { + // Map the same data under multiple keys + key, val := common.LeftPadBytes([]byte{byte(index), 1, j}, 32), []byte{j} + tries[index].Update(key, val) + + key, val = common.LeftPadBytes([]byte{byte(index), 2, j}, 32), []byte{j} + tries[index].Update(key, val) + + // Add some other data to inflate the trie + for k := byte(3); k < 13; k++ { + key, val = common.LeftPadBytes([]byte{byte(index), k, j}, 32), []byte{k, j} + tries[index].Update(key, val) + } + } + tries[index].Commit(false) + }(i) + } + // Wait for all threads to finish + pend.Wait() +} diff --git a/trie_by_cid/trie/tracer.go b/trie_by_cid/trie/tracer.go new file mode 100644 index 0000000..a27e371 --- /dev/null +++ b/trie_by_cid/trie/tracer.go @@ -0,0 +1,125 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. 
+// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package trie + +import "github.com/ethereum/go-ethereum/common" + +// tracer tracks the changes of trie nodes. During the trie operations, +// some nodes can be deleted from the trie, while these deleted nodes +// won't be captured by trie.Hasher or trie.Committer. Thus, these deleted +// nodes won't be removed from the disk at all. Tracer is an auxiliary tool +// used to track all insert and delete operations of trie and capture all +// deleted nodes eventually. +// +// The changed nodes can be mainly divided into two categories: the leaf +// node and intermediate node. The former is inserted/deleted by callers +// while the latter is inserted/deleted in order to follow the rule of trie. +// This tool can track all of them no matter the node is embedded in its +// parent or not, but valueNode is never tracked. +// +// Besides, it's also used for recording the original value of the nodes +// when they are resolved from the disk. The pre-value of the nodes will +// be used to construct trie history in the future. +// +// Note tracer is not thread-safe, callers should be responsible for handling +// the concurrency issues by themselves. 
+type tracer struct { + inserts map[string]struct{} + deletes map[string]struct{} + accessList map[string][]byte +} + +// newTracer initializes the tracer for capturing trie changes. +func newTracer() *tracer { + return &tracer{ + inserts: make(map[string]struct{}), + deletes: make(map[string]struct{}), + accessList: make(map[string][]byte), + } +} + +// onRead tracks the newly loaded trie node and caches the rlp-encoded +// blob internally. Don't change the value outside of function since +// it's not deep-copied. +func (t *tracer) onRead(path []byte, val []byte) { + t.accessList[string(path)] = val +} + +// onInsert tracks the newly inserted trie node. If it's already +// in the deletion set (resurrected node), then just wipe it from +// the deletion set as it's "untouched". +func (t *tracer) onInsert(path []byte) { + if _, present := t.deletes[string(path)]; present { + delete(t.deletes, string(path)) + return + } + t.inserts[string(path)] = struct{}{} +} + +// onDelete tracks the newly deleted trie node. If it's already +// in the addition set, then just wipe it from the addition set +// as it's untouched. +func (t *tracer) onDelete(path []byte) { + if _, present := t.inserts[string(path)]; present { + delete(t.inserts, string(path)) + return + } + t.deletes[string(path)] = struct{}{} +} + +// reset clears the content tracked by tracer. +func (t *tracer) reset() { + t.inserts = make(map[string]struct{}) + t.deletes = make(map[string]struct{}) + t.accessList = make(map[string][]byte) +} + +// copy returns a deep copied tracer instance. 
+func (t *tracer) copy() *tracer { + var ( + inserts = make(map[string]struct{}) + deletes = make(map[string]struct{}) + accessList = make(map[string][]byte) + ) + for path := range t.inserts { + inserts[path] = struct{}{} + } + for path := range t.deletes { + deletes[path] = struct{}{} + } + for path, blob := range t.accessList { + accessList[path] = common.CopyBytes(blob) + } + return &tracer{ + inserts: inserts, + deletes: deletes, + accessList: accessList, + } +} + +// markDeletions puts all tracked deletions into the provided nodeset. +func (t *tracer) markDeletions(set *NodeSet) { + for path := range t.deletes { + // It's possible a few deleted nodes were embedded + // in their parent before, the deletions can be no + // effect by deleting nothing, filter them out. + if _, ok := set.accessList[path]; !ok { + continue + } + set.markDeleted([]byte(path)) + } +} diff --git a/trie_by_cid/trie/trie.go b/trie_by_cid/trie/trie.go index b2d8986..5ad22c1 100644 --- a/trie_by_cid/trie/trie.go +++ b/trie_by_cid/trie/trie.go @@ -7,18 +7,15 @@ import ( "fmt" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/core/types" + log "github.com/sirupsen/logrus" - util "github.com/cerc-io/ipld-eth-statedb/internal" "github.com/ethereum/go-ethereum/statediff/indexer/ipld" ) var ( - // emptyRoot is the known root hash of an empty trie. - emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") - - // emptyState is the known hash of an empty state trie entry. - emptyState = crypto.Keccak256Hash(nil) + StateTrieCodec uint64 = ipld.MEthStateTrie + StorageTrieCodec uint64 = ipld.MEthStorageTrie ) // Trie is a Merkle Patricia Trie. Use New to create a trie that sits on @@ -37,29 +34,48 @@ type Trie struct { // actually unhashed nodes. unhashed int - // db is the handler trie can retrieve nodes from. It's - // only for reading purpose and not available for writing. 
- db *Database + // reader is the handler trie can retrieve nodes from. + reader *trieReader - // Multihash codec for key encoding - codec uint64 + // tracer is the tool to track the trie changes. + // It will be reset after each commit operation. + tracer *tracer } -// New creates a trie with an existing root node from db and an assigned -// owner for storage proximity. -// -// If root is the zero hash or the sha3 hash of an empty string, the -// trie is initially empty and does not require a database. Otherwise, -// New will panic if db is nil and returns a MissingNodeError if root does -// not exist in the database. Accessing the trie loads nodes from db on demand. -func New(owner common.Hash, root common.Hash, db *Database, codec uint64) (*Trie, error) { - trie := &Trie{ - owner: owner, - db: db, - codec: codec, +// newFlag returns the cache flag value for a newly created node. +func (t *Trie) newFlag() nodeFlag { + return nodeFlag{dirty: true} +} + +// Copy returns a copy of Trie. +func (t *Trie) Copy() *Trie { + return &Trie{ + root: t.root, + owner: t.owner, + unhashed: t.unhashed, + reader: t.reader, + tracer: t.tracer.copy(), } - if root != (common.Hash{}) && root != emptyRoot { - rootnode, err := trie.resolveHash(root[:], nil) +} + +// New creates a trie instance with the provided trie id and the read-only +// database. The state specified by trie id must be available, otherwise +// an error will be returned. The trie root specified by trie id can be +// zero hash or the sha3 hash of an empty string, then trie is initially +// empty, otherwise, the root node must be present in database or returns +// a MissingNodeError if not. 
+func New(id *ID, db NodeReader, codec uint64) (*Trie, error) { + reader, err := newTrieReader(id.StateRoot, id.Owner, db, codec) + if err != nil { + return nil, err + } + trie := &Trie{ + owner: id.Owner, + reader: reader, + tracer: newTracer(), + } + if id.Root != (common.Hash{}) && id.Root != types.EmptyRootHash { + rootnode, err := trie.resolveAndTrack(id.Root[:], nil) if err != nil { return nil, err } @@ -70,7 +86,10 @@ func New(owner common.Hash, root common.Hash, db *Database, codec uint64) (*Trie // NewEmpty is a shortcut to create empty tree. It's mostly used in tests. func NewEmpty(db *Database) *Trie { - tr, _ := New(common.Hash{}, common.Hash{}, db, ipld.MEthStateTrie) + tr, err := New(TrieID(common.Hash{}), db, StateTrieCodec) + if err != nil { + panic(err) + } return tr } @@ -80,6 +99,16 @@ func (t *Trie) NodeIterator(start []byte) NodeIterator { return newNodeIterator(t, start) } +// Get returns the value for key stored in the trie. +// The value bytes must not be modified by the caller. +func (t *Trie) Get(key []byte) []byte { + res, err := t.TryGet(key) + if err != nil { + log.Error("Unhandled trie error in Trie.Get", "err", err) + } + return res +} + // TryGet returns the value for key stored in the trie. // The value bytes must not be modified by the caller. // If a node was not found in the database, a MissingNodeError is returned. @@ -116,7 +145,7 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode } return value, n, didResolve, err case hashNode: - child, err := t.resolveHash(n, key[:pos]) + child, err := t.resolveAndTrack(n, key[:pos]) if err != nil { return nil, n, true, err } @@ -130,7 +159,7 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode // TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not // possible to use keybyte-encoding as the path might contain odd nibbles. 
func (t *Trie) TryGetNode(path []byte) ([]byte, int, error) { - item, newroot, resolved, err := t.tryGetNode(t.root, CompactToHex(path), 0) + item, newroot, resolved, err := t.tryGetNode(t.root, compactToHex(path), 0) if err != nil { return nil, resolved, err } @@ -162,11 +191,7 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new if hash == nil { return nil, origNode, 0, errors.New("non-consensus node") } - cid, err := util.Keccak256ToCid(t.codec, hash) - if err != nil { - return nil, origNode, 0, err - } - blob, err := t.db.Node(cid) + blob, err := t.reader.nodeBlob(path, common.BytesToHash(hash)) return blob, origNode, 1, err } // Path still needs to be traversed, descend into children @@ -196,7 +221,7 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new return item, n, resolved, err case hashNode: - child, err := t.resolveHash(n, path[:pos]) + child, err := t.resolveAndTrack(n, path[:pos]) if err != nil { return nil, n, 1, err } @@ -208,42 +233,309 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new } } -// resolveHash loads node from the underlying database with the provided -// node hash and path prefix. -func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) { - cid, err := util.Keccak256ToCid(t.codec, n) - if err != nil { - return nil, err +// Update associates key with value in the trie. Subsequent calls to +// Get will return value. If value has length zero, any existing value +// is deleted from the trie and calls to Get will return nil. +// +// The value bytes must not be modified by the caller while they are +// stored in the trie. 
+func (t *Trie) Update(key, value []byte) { + if err := t.TryUpdate(key, value); err != nil { + log.Error("Unhandled trie error in Trie.Update", "err", err) } - enc, err := t.db.Node(cid) - if err != nil { - return nil, &MissingNodeError{Owner: t.owner, NodeHash: n, Path: prefix, err: err} - } - node, err := decodeNodeUnsafe(n, enc) - if err != nil { - return nil, err - } - if node != nil { - return node, nil - } - return nil, &MissingNodeError{Owner: t.owner, NodeHash: n, Path: prefix} } -// resolveHash loads rlp-encoded node blob from the underlying database -// with the provided node hash and path prefix. -func (t *Trie) resolveBlob(n hashNode, prefix []byte) ([]byte, error) { - cid, err := util.Keccak256ToCid(t.codec, n) +// TryUpdate associates key with value in the trie. Subsequent calls to +// Get will return value. If value has length zero, any existing value +// is deleted from the trie and calls to Get will return nil. +// +// The value bytes must not be modified by the caller while they are +// stored in the trie. +// +// If a node was not found in the database, a MissingNodeError is returned. +func (t *Trie) TryUpdate(key, value []byte) error { + return t.tryUpdate(key, value) +} + +// tryUpdate expects an RLP-encoded value and performs the core function +// for TryUpdate and TryUpdateAccount. 
+func (t *Trie) tryUpdate(key, value []byte) error { + t.unhashed++ + k := keybytesToHex(key) + if len(value) != 0 { + _, n, err := t.insert(t.root, nil, k, valueNode(value)) + if err != nil { + return err + } + t.root = n + } else { + _, n, err := t.delete(t.root, nil, k) + if err != nil { + return err + } + t.root = n + } + return nil +} + +func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error) { + if len(key) == 0 { + if v, ok := n.(valueNode); ok { + return !bytes.Equal(v, value.(valueNode)), value, nil + } + return true, value, nil + } + switch n := n.(type) { + case *shortNode: + matchlen := prefixLen(key, n.Key) + // If the whole key matches, keep this short node as is + // and only update the value. + if matchlen == len(n.Key) { + dirty, nn, err := t.insert(n.Val, append(prefix, key[:matchlen]...), key[matchlen:], value) + if !dirty || err != nil { + return false, n, err + } + return true, &shortNode{n.Key, nn, t.newFlag()}, nil + } + // Otherwise branch out at the index where they differ. + branch := &fullNode{flags: t.newFlag()} + var err error + _, branch.Children[n.Key[matchlen]], err = t.insert(nil, append(prefix, n.Key[:matchlen+1]...), n.Key[matchlen+1:], n.Val) + if err != nil { + return false, nil, err + } + _, branch.Children[key[matchlen]], err = t.insert(nil, append(prefix, key[:matchlen+1]...), key[matchlen+1:], value) + if err != nil { + return false, nil, err + } + // Replace this shortNode with the branch if it occurs at index 0. + if matchlen == 0 { + return true, branch, nil + } + // New branch node is created as a child of the original short node. + // Track the newly inserted node in the tracer. The node identifier + // passed is the path from the root node. + t.tracer.onInsert(append(prefix, key[:matchlen]...)) + + // Replace it with a short node leading up to the branch. 
+ return true, &shortNode{key[:matchlen], branch, t.newFlag()}, nil + + case *fullNode: + dirty, nn, err := t.insert(n.Children[key[0]], append(prefix, key[0]), key[1:], value) + if !dirty || err != nil { + return false, n, err + } + n = n.copy() + n.flags = t.newFlag() + n.Children[key[0]] = nn + return true, n, nil + + case nil: + // New short node is created and track it in the tracer. The node identifier + // passed is the path from the root node. Note the valueNode won't be tracked + // since it's always embedded in its parent. + t.tracer.onInsert(prefix) + + return true, &shortNode{key, value, t.newFlag()}, nil + + case hashNode: + // We've hit a part of the trie that isn't loaded yet. Load + // the node and insert into it. This leaves all child nodes on + // the path to the value in the trie. + rn, err := t.resolveAndTrack(n, prefix) + if err != nil { + return false, nil, err + } + dirty, nn, err := t.insert(rn, prefix, key, value) + if !dirty || err != nil { + return false, rn, err + } + return true, nn, nil + + default: + panic(fmt.Sprintf("%T: invalid node: %v", n, n)) + } +} + +// Delete removes any existing value for key from the trie. +func (t *Trie) Delete(key []byte) { + if err := t.TryDelete(key); err != nil { + log.Error("Unhandled trie error in Trie.Delete", "err", err) + } +} + +// TryDelete removes any existing value for key from the trie. +// If a node was not found in the database, a MissingNodeError is returned. +func (t *Trie) TryDelete(key []byte) error { + t.unhashed++ + k := keybytesToHex(key) + _, n, err := t.delete(t.root, nil, k) + if err != nil { + return err + } + t.root = n + return nil +} + +// delete returns the new root of the trie with key deleted. +// It reduces the trie to minimal form by simplifying +// nodes on the way up after deleting recursively. 
+func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { + switch n := n.(type) { + case *shortNode: + matchlen := prefixLen(key, n.Key) + if matchlen < len(n.Key) { + return false, n, nil // don't replace n on mismatch + } + if matchlen == len(key) { + // The matched short node is deleted entirely and track + // it in the deletion set. The same the valueNode doesn't + // need to be tracked at all since it's always embedded. + t.tracer.onDelete(prefix) + + return true, nil, nil // remove n entirely for whole matches + } + // The key is longer than n.Key. Remove the remaining suffix + // from the subtrie. Child can never be nil here since the + // subtrie must contain at least two other values with keys + // longer than n.Key. + dirty, child, err := t.delete(n.Val, append(prefix, key[:len(n.Key)]...), key[len(n.Key):]) + if !dirty || err != nil { + return false, n, err + } + switch child := child.(type) { + case *shortNode: + // The child shortNode is merged into its parent, track + // is deleted as well. + t.tracer.onDelete(append(prefix, n.Key...)) + + // Deleting from the subtrie reduced it to another + // short node. Merge the nodes to avoid creating a + // shortNode{..., shortNode{...}}. Use concat (which + // always creates a new slice) instead of append to + // avoid modifying n.Key since it might be shared with + // other nodes. + return true, &shortNode{concat(n.Key, child.Key...), child.Val, t.newFlag()}, nil + default: + return true, &shortNode{n.Key, child, t.newFlag()}, nil + } + + case *fullNode: + dirty, nn, err := t.delete(n.Children[key[0]], append(prefix, key[0]), key[1:]) + if !dirty || err != nil { + return false, n, err + } + n = n.copy() + n.flags = t.newFlag() + n.Children[key[0]] = nn + + // Because n is a full node, it must've contained at least two children + // before the delete operation. 
If the new child value is non-nil, n still + // has at least two children after the deletion, and cannot be reduced to + // a short node. + if nn != nil { + return true, n, nil + } + // Reduction: + // Check how many non-nil entries are left after deleting and + // reduce the full node to a short node if only one entry is + // left. Since n must've contained at least two children + // before deletion (otherwise it would not be a full node) n + // can never be reduced to nil. + // + // When the loop is done, pos contains the index of the single + // value that is left in n or -2 if n contains at least two + // values. + pos := -1 + for i, cld := range &n.Children { + if cld != nil { + if pos == -1 { + pos = i + } else { + pos = -2 + break + } + } + } + if pos >= 0 { + if pos != 16 { + // If the remaining entry is a short node, it replaces + // n and its key gets the missing nibble tacked to the + // front. This avoids creating an invalid + // shortNode{..., shortNode{...}}. Since the entry + // might not be loaded yet, resolve it just for this + // check. + cnode, err := t.resolve(n.Children[pos], append(prefix, byte(pos))) + if err != nil { + return false, nil, err + } + if cnode, ok := cnode.(*shortNode); ok { + // Replace the entire full node with the short node. + // Mark the original short node as deleted since the + // value is embedded into the parent now. + t.tracer.onDelete(append(prefix, byte(pos))) + + k := append([]byte{byte(pos)}, cnode.Key...) + return true, &shortNode{k, cnode.Val, t.newFlag()}, nil + } + } + // Otherwise, n is replaced by a one-nibble short node + // containing the child. + return true, &shortNode{[]byte{byte(pos)}, n.Children[pos], t.newFlag()}, nil + } + // n still contains at least two values and cannot be reduced. + return true, n, nil + + case valueNode: + return true, nil, nil + + case nil: + return false, nil, nil + + case hashNode: + // We've hit a part of the trie that isn't loaded yet. Load + // the node and delete from it. 
This leaves all child nodes on + // the path to the value in the trie. + rn, err := t.resolveAndTrack(n, prefix) + if err != nil { + return false, nil, err + } + dirty, nn, err := t.delete(rn, prefix, key) + if !dirty || err != nil { + return false, rn, err + } + return true, nn, nil + + default: + panic(fmt.Sprintf("%T: invalid node: %v (%v)", n, n, key)) + } +} + +func concat(s1 []byte, s2 ...byte) []byte { + r := make([]byte, len(s1)+len(s2)) + copy(r, s1) + copy(r[len(s1):], s2) + return r +} + +func (t *Trie) resolve(n node, prefix []byte) (node, error) { + if n, ok := n.(hashNode); ok { + return t.resolveAndTrack(n, prefix) + } + return n, nil +} + +// resolveAndTrack loads node from the underlying store with the given node hash +// and path prefix and also tracks the loaded node blob in tracer treated as the +// node's original value. The rlp-encoded blob is preferred to be loaded from +// database because it's easy to decode node while complex to encode node to blob. +func (t *Trie) resolveAndTrack(n hashNode, prefix []byte) (node, error) { + blob, err := t.reader.nodeBlob(prefix, common.BytesToHash(n)) if err != nil { return nil, err } - blob, err := t.db.Node(cid) - if err != nil { - return nil, err - } - if len(blob) != 0 { - return blob, nil - } - return nil, &MissingNodeError{Owner: t.owner, NodeHash: n, Path: prefix} + t.tracer.onRead(prefix, blob) + return mustDecodeNode(n, blob), nil } // Hash returns the root hash of the trie. It does not write to the @@ -254,15 +546,59 @@ func (t *Trie) Hash() common.Hash { return common.BytesToHash(hash.(hashNode)) } +// Commit collects all dirty nodes in the trie and replaces them with the +// corresponding node hash. All collected nodes (including dirty leaves if +// collectLeaf is true) will be encapsulated into a nodeset for return. +// The returned nodeset can be nil if the trie is clean (nothing to commit). +// Once the trie is committed, it's not usable anymore. 
A new trie must +// be created with new root and updated trie database for following usage +func (t *Trie) Commit(collectLeaf bool) (common.Hash, *NodeSet) { + defer t.tracer.reset() + + nodes := NewNodeSet(t.owner, t.tracer.accessList) + t.tracer.markDeletions(nodes) + + // Trie is empty and can be classified into two types of situations: + // - The trie was empty and no update happens + // - The trie was non-empty and all nodes are dropped + if t.root == nil { + return types.EmptyRootHash, nodes + } + // Derive the hash for all dirty nodes first. We hold the assumption + // in the following procedure that all nodes are hashed. + rootHash := t.Hash() + + // Do a quick check if we really need to commit. This can happen e.g. + // if we load a trie for reading storage values, but don't write to it. + if hashedNode, dirty := t.root.cache(); !dirty { + // Replace the root node with the origin hash in order to + // ensure all resolved nodes are dropped after the commit. + t.root = hashedNode + return rootHash, nil + } + t.root = newCommitter(nodes, collectLeaf).Commit(t.root) + return rootHash, nodes +} + // hashRoot calculates the root hash of the given trie func (t *Trie) hashRoot() (node, node) { if t.root == nil { - return hashNode(emptyRoot.Bytes()), nil + return hashNode(types.EmptyRootHash.Bytes()), nil } // If the number of changes is below 100, we let one thread handle it h := newHasher(t.unhashed >= 100) - defer returnHasherToPool(h) + defer func() { + returnHasherToPool(h) + t.unhashed = 0 + }() hashed, cached := h.hash(t.root, true) - t.unhashed = 0 return hashed, cached } + +// Reset drops the referenced root node and cleans all internal state. 
+func (t *Trie) Reset() { + t.root = nil + t.owner = common.Hash{} + t.unhashed = 0 + t.tracer.reset() +} diff --git a/trie_by_cid/trie/trie_id.go b/trie_by_cid/trie/trie_id.go new file mode 100644 index 0000000..8ab490c --- /dev/null +++ b/trie_by_cid/trie/trie_id.go @@ -0,0 +1,55 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see + +package trie + +import "github.com/ethereum/go-ethereum/common" + +// ID is the identifier for uniquely identifying a trie. +type ID struct { + StateRoot common.Hash // The root of the corresponding state(block.root) + Owner common.Hash // The contract address hash which the trie belongs to + Root common.Hash // The root hash of trie +} + +// StateTrieID constructs an identifier for state trie with the provided state root. +func StateTrieID(root common.Hash) *ID { + return &ID{ + StateRoot: root, + Owner: common.Hash{}, + Root: root, + } +} + +// StorageTrieID constructs an identifier for storage trie which belongs to a certain +// state and contract specified by the stateRoot and owner. 
+func StorageTrieID(stateRoot common.Hash, owner common.Hash, root common.Hash) *ID { + return &ID{ + StateRoot: stateRoot, + Owner: owner, + Root: root, + } +} + +// TrieID constructs an identifier for a standard trie(not a second-layer trie) +// with provided root. It's mostly used in tests and some other tries like CHT trie. +func TrieID(root common.Hash) *ID { + return &ID{ + StateRoot: root, + Owner: common.Hash{}, + Root: root, + } +} diff --git a/trie_by_cid/trie/trie_reader.go b/trie_by_cid/trie/trie_reader.go new file mode 100644 index 0000000..b0a7fdd --- /dev/null +++ b/trie_by_cid/trie/trie_reader.go @@ -0,0 +1,106 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package trie + +import ( + "fmt" + + "github.com/ethereum/go-ethereum/common" +) + +// Reader wraps the Node and NodeBlob method of a backing trie store. +type Reader interface { + // Node retrieves the trie node with the provided trie identifier, hexary + // node path and the corresponding node hash. + // No error will be returned if the node is not found. + Node(owner common.Hash, path []byte, hash common.Hash) (node, error) + + // NodeBlob retrieves the RLP-encoded trie node blob with the provided trie + // identifier, hexary node path and the corresponding node hash. 
+ // No error will be returned if the node is not found. + NodeBlob(owner common.Hash, path []byte, hash common.Hash) ([]byte, error) +} + +// NodeReader wraps all the necessary functions for accessing trie node. +type NodeReader interface { + // GetReader returns a reader for accessing all trie nodes with provided + // state root. Nil is returned in case the state is not available. + GetReader(root common.Hash, codec uint64) Reader +} + +// trieReader is a wrapper of the underlying node reader. It's not safe +// for concurrent usage. +type trieReader struct { + owner common.Hash + reader Reader + banned map[string]struct{} // Marker to prevent node from being accessed, for tests +} + +// newTrieReader initializes the trie reader with the given node reader. +func newTrieReader(stateRoot, owner common.Hash, db NodeReader, codec uint64) (*trieReader, error) { + reader := db.GetReader(stateRoot, codec) + if reader == nil { + return nil, fmt.Errorf("state not found #%x", stateRoot) + } + return &trieReader{owner: owner, reader: reader}, nil +} + +// newEmptyReader initializes the pure in-memory reader. All read operations +// should be forbidden and returns the MissingNodeError. +func newEmptyReader() *trieReader { + return &trieReader{} +} + +// node retrieves the trie node with the provided trie node information. +// An MissingNodeError will be returned in case the node is not found or +// any error is encountered. +func (r *trieReader) node(path []byte, hash common.Hash) (node, error) { + // Perform the logics in tests for preventing trie node access. 
+ if r.banned != nil { + if _, ok := r.banned[string(path)]; ok { + return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path} + } + } + if r.reader == nil { + return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path} + } + node, err := r.reader.Node(r.owner, path, hash) + if err != nil || node == nil { + return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path, err: err} + } + return node, nil +} + +// node retrieves the rlp-encoded trie node with the provided trie node +// information. An MissingNodeError will be returned in case the node is +// not found or any error is encountered. +func (r *trieReader) nodeBlob(path []byte, hash common.Hash) ([]byte, error) { + // Perform the logics in tests for preventing trie node access. + if r.banned != nil { + if _, ok := r.banned[string(path)]; ok { + return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path} + } + } + if r.reader == nil { + return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path} + } + blob, err := r.reader.NodeBlob(r.owner, path, hash) + if err != nil || len(blob) == 0 { + return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path, err: err} + } + return blob, nil +} diff --git a/trie_by_cid/trie/trie_test.go b/trie_by_cid/trie/trie_test.go index 0418409..76aaf72 100644 --- a/trie_by_cid/trie/trie_test.go +++ b/trie_by_cid/trie/trie_test.go @@ -14,26 +14,34 @@ // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . 
-package trie_test +package trie import ( + "bytes" + "encoding/binary" + "errors" "fmt" "math/big" "math/rand" + "reflect" "testing" + "testing/quick" + "github.com/davecgh/go-spew/spew" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/rlp" - geth_trie "github.com/ethereum/go-ethereum/trie" - - "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie" ) -func TestTrieEmpty(t *testing.T) { - trie := trie.NewEmpty(trie.NewDatabase(rawdb.NewMemoryDatabase())) +func init() { + spew.Config.Indent = " " + spew.Config.DisableMethods = false +} + +func TestEmptyTrie(t *testing.T) { + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) res := trie.Hash() exp := types.EmptyRootHash if res != exp { @@ -41,42 +49,543 @@ func TestTrieEmpty(t *testing.T) { } } -func TestTrieMissingRoot(t *testing.T) { +func TestNull(t *testing.T) { + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + key := make([]byte, 32) + value := []byte("test") + trie.Update(key, value) + if !bytes.Equal(trie.Get(key), value) { + t.Fatal("wrong value") + } +} + +func TestMissingRoot(t *testing.T) { root := common.HexToHash("0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33") - tr, err := newStateTrie(root, trie.NewDatabase(rawdb.NewMemoryDatabase())) - if tr != nil { + trie, err := NewAccountTrie(TrieID(root), NewDatabase(rawdb.NewMemoryDatabase())) + if trie != nil { t.Error("New returned non-nil trie for invalid root") } - if _, ok := err.(*trie.MissingNodeError); !ok { + if _, ok := err.(*MissingNodeError); !ok { t.Errorf("New returned wrong error: %v", err) } } -func TestTrieBasic(t *testing.T) { - edb := rawdb.NewMemoryDatabase() - db := geth_trie.NewDatabase(edb) - origtrie := geth_trie.NewEmpty(db) - origtrie.Update([]byte("foo"), packValue(842)) - expected := commitTrie(t, db, origtrie) - tr := indexTrie(t, edb, expected) - got := tr.Hash() - 
if expected != got { - t.Errorf("got %x expected %x", got, expected) +func TestMissingNodeMemonly(t *testing.T) { testMissingNode(t, true) } + +func testMissingNode(t *testing.T, memonly bool) { + diskdb := rawdb.NewMemoryDatabase() + triedb := NewDatabase(diskdb) + + trie := NewEmpty(triedb) + updateString(trie, "120000", "qwerqwerqwerqwerqwerqwerqwerqwer") + updateString(trie, "123456", "asdfasdfasdfasdfasdfasdfasdfasdf") + root, nodes := trie.Commit(false) + triedb.Update(NewWithNodeSet(nodes)) + + trie, _ = NewAccountTrie(TrieID(root), triedb) + _, err := trie.TryGet([]byte("120000")) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + trie, _ = NewAccountTrie(TrieID(root), triedb) + _, err = trie.TryGet([]byte("120099")) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + trie, _ = NewAccountTrie(TrieID(root), triedb) + _, err = trie.TryGet([]byte("123456")) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + trie, _ = NewAccountTrie(TrieID(root), triedb) + err = trie.TryUpdate([]byte("120099"), []byte("zxcvzxcvzxcvzxcvzxcvzxcvzxcvzxcv")) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + trie, _ = NewAccountTrie(TrieID(root), triedb) + err = trie.TryDelete([]byte("123456")) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + hash := common.HexToHash("0xe1d943cc8f061a0c0b98162830b970395ac9315654824bf21b73b891365262f9") + if memonly { + delete(triedb.dirties, hash) + } else { + diskdb.Delete(hash[:]) + } + + trie, _ = NewAccountTrie(TrieID(root), triedb) + _, err = trie.TryGet([]byte("120000")) + if _, ok := err.(*MissingNodeError); !ok { + t.Errorf("Wrong error: %v", err) + } + trie, _ = NewAccountTrie(TrieID(root), triedb) + _, err = trie.TryGet([]byte("120099")) + if _, ok := err.(*MissingNodeError); !ok { + t.Errorf("Wrong error: %v", err) + } + trie, _ = NewAccountTrie(TrieID(root), triedb) + _, err = trie.TryGet([]byte("123456")) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + 
trie, _ = NewAccountTrie(TrieID(root), triedb) + err = trie.TryUpdate([]byte("120099"), []byte("zxcv")) + if _, ok := err.(*MissingNodeError); !ok { + t.Errorf("Wrong error: %v", err) + } + trie, _ = NewAccountTrie(TrieID(root), triedb) + err = trie.TryDelete([]byte("123456")) + if _, ok := err.(*MissingNodeError); !ok { + t.Errorf("Wrong error: %v", err) } - checkValue(t, tr, []byte("foo")) } -func TestTrieTiny(t *testing.T) { +func TestInsert(t *testing.T) { + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + + updateString(trie, "doe", "reindeer") + updateString(trie, "dog", "puppy") + updateString(trie, "dogglesworth", "cat") + + exp := common.HexToHash("8aad789dff2f538bca5d8ea56e8abe10f4c7ba3a5dea95fea4cd6e7c3a1168d3") + root := trie.Hash() + if root != exp { + t.Errorf("case 1: exp %x got %x", exp, root) + } + + trie = NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + updateString(trie, "A", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") + + exp = common.HexToHash("d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab") + root, _ = trie.Commit(false) + if root != exp { + t.Errorf("case 2: exp %x got %x", exp, root) + } +} + +func TestGet(t *testing.T) { + db := NewDatabase(rawdb.NewMemoryDatabase()) + trie := NewEmpty(db) + updateString(trie, "doe", "reindeer") + updateString(trie, "dog", "puppy") + updateString(trie, "dogglesworth", "cat") + + for i := 0; i < 2; i++ { + res := getString(trie, "dog") + if !bytes.Equal(res, []byte("puppy")) { + t.Errorf("expected puppy got %x", res) + } + unknown := getString(trie, "unknown") + if unknown != nil { + t.Errorf("expected nil got %x", unknown) + } + if i == 1 { + return + } + root, nodes := trie.Commit(false) + db.Update(NewWithNodeSet(nodes)) + trie, _ = NewAccountTrie(TrieID(root), db) + } +} + +func TestDelete(t *testing.T) { + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + vals := []struct{ k, v string }{ + {"do", "verb"}, + {"ether", "wookiedoo"}, + {"horse", 
"stallion"}, + {"shaman", "horse"}, + {"doge", "coin"}, + {"ether", ""}, + {"dog", "puppy"}, + {"shaman", ""}, + } + for _, val := range vals { + if val.v != "" { + updateString(trie, val.k, val.v) + } else { + deleteString(trie, val.k) + } + } + + hash := trie.Hash() + exp := common.HexToHash("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84") + if hash != exp { + t.Errorf("expected %x got %x", exp, hash) + } +} + +func TestEmptyValues(t *testing.T) { + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + + vals := []struct{ k, v string }{ + {"do", "verb"}, + {"ether", "wookiedoo"}, + {"horse", "stallion"}, + {"shaman", "horse"}, + {"doge", "coin"}, + {"ether", ""}, + {"dog", "puppy"}, + {"shaman", ""}, + } + for _, val := range vals { + updateString(trie, val.k, val.v) + } + + hash := trie.Hash() + exp := common.HexToHash("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84") + if hash != exp { + t.Errorf("expected %x got %x", exp, hash) + } +} + +func TestReplication(t *testing.T) { + triedb := NewDatabase(rawdb.NewMemoryDatabase()) + trie := NewEmpty(triedb) + vals := []struct{ k, v string }{ + {"do", "verb"}, + {"ether", "wookiedoo"}, + {"horse", "stallion"}, + {"shaman", "horse"}, + {"doge", "coin"}, + {"dog", "puppy"}, + {"somethingveryoddindeedthis is", "myothernodedata"}, + } + for _, val := range vals { + updateString(trie, val.k, val.v) + } + exp, nodes := trie.Commit(false) + triedb.Update(NewWithNodeSet(nodes)) + + // create a new trie on top of the database and check that lookups work. + trie2, err := NewAccountTrie(TrieID(exp), triedb) + if err != nil { + t.Fatalf("can't recreate trie at %x: %v", exp, err) + } + for _, kv := range vals { + if string(getString(trie2, kv.k)) != kv.v { + t.Errorf("trie2 doesn't have %q => %q", kv.k, kv.v) + } + } + hash, nodes := trie2.Commit(false) + if hash != exp { + t.Errorf("root failure. 
expected %x got %x", exp, hash) + } + + // recreate the trie after commit + if nodes != nil { + triedb.Update(NewWithNodeSet(nodes)) + } + trie2, err = NewAccountTrie(TrieID(hash), triedb) + if err != nil { + t.Fatalf("can't recreate trie at %x: %v", exp, err) + } + // perform some insertions on the new trie. + vals2 := []struct{ k, v string }{ + {"do", "verb"}, + {"ether", "wookiedoo"}, + {"horse", "stallion"}, + // {"shaman", "horse"}, + // {"doge", "coin"}, + // {"ether", ""}, + // {"dog", "puppy"}, + // {"somethingveryoddindeedthis is", "myothernodedata"}, + // {"shaman", ""}, + } + for _, val := range vals2 { + updateString(trie2, val.k, val.v) + } + if hash := trie2.Hash(); hash != exp { + t.Errorf("root failure. expected %x got %x", exp, hash) + } +} + +func TestLargeValue(t *testing.T) { + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + trie.Update([]byte("key1"), []byte{99, 99, 99, 99}) + trie.Update([]byte("key2"), bytes.Repeat([]byte{1}, 32)) + trie.Hash() +} + +// TestRandomCases tests some cases that were found via random fuzzing +func TestRandomCases(t *testing.T) { + var rt = []randTestStep{ + {op: 6, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 0 + {op: 6, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 1 + {op: 0, key: common.Hex2Bytes("d51b182b95d677e5f1c82508c0228de96b73092d78ce78b2230cd948674f66fd1483bd"), value: common.Hex2Bytes("0000000000000002")}, // step 2 + {op: 2, key: common.Hex2Bytes("c2a38512b83107d665c65235b0250002882ac2022eb00711552354832c5f1d030d0e408e"), value: common.Hex2Bytes("")}, // step 3 + {op: 3, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 4 + {op: 3, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 5 + {op: 6, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 6 + {op: 3, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 7 + {op: 0, key: 
common.Hex2Bytes("c2a38512b83107d665c65235b0250002882ac2022eb00711552354832c5f1d030d0e408e"), value: common.Hex2Bytes("0000000000000008")}, // step 8 + {op: 0, key: common.Hex2Bytes("d51b182b95d677e5f1c82508c0228de96b73092d78ce78b2230cd948674f66fd1483bd"), value: common.Hex2Bytes("0000000000000009")}, // step 9 + {op: 2, key: common.Hex2Bytes("fd"), value: common.Hex2Bytes("")}, // step 10 + {op: 6, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 11 + {op: 6, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 12 + {op: 0, key: common.Hex2Bytes("fd"), value: common.Hex2Bytes("000000000000000d")}, // step 13 + {op: 6, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 14 + {op: 1, key: common.Hex2Bytes("c2a38512b83107d665c65235b0250002882ac2022eb00711552354832c5f1d030d0e408e"), value: common.Hex2Bytes("")}, // step 15 + {op: 3, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 16 + {op: 0, key: common.Hex2Bytes("c2a38512b83107d665c65235b0250002882ac2022eb00711552354832c5f1d030d0e408e"), value: common.Hex2Bytes("0000000000000011")}, // step 17 + {op: 5, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 18 + {op: 3, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 19 + {op: 0, key: common.Hex2Bytes("d51b182b95d677e5f1c82508c0228de96b73092d78ce78b2230cd948674f66fd1483bd"), value: common.Hex2Bytes("0000000000000014")}, // step 20 + {op: 0, key: common.Hex2Bytes("d51b182b95d677e5f1c82508c0228de96b73092d78ce78b2230cd948674f66fd1483bd"), value: common.Hex2Bytes("0000000000000015")}, // step 21 + {op: 0, key: common.Hex2Bytes("c2a38512b83107d665c65235b0250002882ac2022eb00711552354832c5f1d030d0e408e"), value: common.Hex2Bytes("0000000000000016")}, // step 22 + {op: 5, key: common.Hex2Bytes(""), value: common.Hex2Bytes("")}, // step 23 + {op: 1, key: common.Hex2Bytes("980c393656413a15c8da01978ed9f89feb80b502f58f2d640e3a2f5f7a99a7018f1b573befd92053ac6f78fca4a87268"), value: 
common.Hex2Bytes("")}, // step 24 + {op: 1, key: common.Hex2Bytes("fd"), value: common.Hex2Bytes("")}, // step 25 + } + runRandTest(rt) +} + +// randTest performs random trie operations. +// Instances of this test are created by Generate. +type randTest []randTestStep + +type randTestStep struct { + op int + key []byte // for opUpdate, opDelete, opGet + value []byte // for opUpdate + err error // for debugging +} + +const ( + opUpdate = iota + opDelete + opGet + opHash + opCommit + opItercheckhash + opNodeDiff + opProve + opMax // boundary value, not an actual op +) + +func (randTest) Generate(r *rand.Rand, size int) reflect.Value { + var allKeys [][]byte + genKey := func() []byte { + if len(allKeys) < 2 || r.Intn(100) < 10 { + // new key + key := make([]byte, r.Intn(50)) + r.Read(key) + allKeys = append(allKeys, key) + return key + } + // use existing key + return allKeys[r.Intn(len(allKeys))] + } + + var steps randTest + for i := 0; i < size; i++ { + step := randTestStep{op: r.Intn(opMax)} + switch step.op { + case opUpdate: + step.key = genKey() + step.value = make([]byte, 8) + binary.BigEndian.PutUint64(step.value, uint64(i)) + case opGet, opDelete, opProve: + step.key = genKey() + } + steps = append(steps, step) + } + return reflect.ValueOf(steps) +} + +func verifyAccessList(old *Trie, new *Trie, set *NodeSet) error { + deletes, inserts, updates := diffTries(old, new) + + // Check insertion set + for path := range inserts { + n, ok := set.nodes[path] + if !ok || n.isDeleted() { + return errors.New("expect new node") + } + _, ok = set.accessList[path] + if ok { + return errors.New("unexpected origin value") + } + } + // Check deletion set + for path, blob := range deletes { + n, ok := set.nodes[path] + if !ok || !n.isDeleted() { + return errors.New("expect deleted node") + } + v, ok := set.accessList[path] + if !ok { + return errors.New("expect origin value") + } + if !bytes.Equal(v, blob) { + return errors.New("invalid origin value") + } + } + // Check update 
set + for path, blob := range updates { + n, ok := set.nodes[path] + if !ok || n.isDeleted() { + return errors.New("expect updated node") + } + v, ok := set.accessList[path] + if !ok { + return errors.New("expect origin value") + } + if !bytes.Equal(v, blob) { + return errors.New("invalid origin value") + } + } + return nil +} + +func runRandTest(rt randTest) bool { + var ( + triedb = NewDatabase(rawdb.NewMemoryDatabase()) + tr = NewEmpty(triedb) + values = make(map[string]string) // tracks content of the trie + origTrie = NewEmpty(triedb) + ) + for i, step := range rt { + // fmt.Printf("{op: %d, key: common.Hex2Bytes(\"%x\"), value: common.Hex2Bytes(\"%x\")}, // step %d\n", + // step.op, step.key, step.value, i) + + switch step.op { + case opUpdate: + tr.Update(step.key, step.value) + values[string(step.key)] = string(step.value) + case opDelete: + tr.Delete(step.key) + delete(values, string(step.key)) + case opGet: + v := tr.Get(step.key) + want := values[string(step.key)] + if string(v) != want { + rt[i].err = fmt.Errorf("mismatch for key %#x, got %#x want %#x", step.key, v, want) + } + case opProve: + hash := tr.Hash() + if hash == types.EmptyRootHash { + continue + } + proofDb := rawdb.NewMemoryDatabase() + err := tr.Prove(step.key, 0, proofDb) + if err != nil { + rt[i].err = fmt.Errorf("failed for proving key %#x, %v", step.key, err) + } + _, err = VerifyProof(hash, step.key, proofDb) + if err != nil { + rt[i].err = fmt.Errorf("failed for verifying key %#x, %v", step.key, err) + } + case opHash: + tr.Hash() + case opCommit: + root, nodes := tr.Commit(true) + if nodes != nil { + triedb.Update(NewWithNodeSet(nodes)) + } + newtr, err := NewAccountTrie(TrieID(root), triedb) + if err != nil { + rt[i].err = err + return false + } + if nodes != nil { + if err := verifyAccessList(origTrie, newtr, nodes); err != nil { + rt[i].err = err + return false + } + } + tr = newtr + origTrie = tr.Copy() + case opItercheckhash: + checktr := NewEmpty(triedb) + it := 
NewIterator(tr.NodeIterator(nil)) + for it.Next() { + checktr.Update(it.Key, it.Value) + } + if tr.Hash() != checktr.Hash() { + rt[i].err = fmt.Errorf("hash mismatch in opItercheckhash") + } + case opNodeDiff: + var ( + origIter = origTrie.NodeIterator(nil) + curIter = tr.NodeIterator(nil) + origSeen = make(map[string]struct{}) + curSeen = make(map[string]struct{}) + ) + for origIter.Next(true) { + if origIter.Leaf() { + continue + } + origSeen[string(origIter.Path())] = struct{}{} + } + for curIter.Next(true) { + if curIter.Leaf() { + continue + } + curSeen[string(curIter.Path())] = struct{}{} + } + var ( + insertExp = make(map[string]struct{}) + deleteExp = make(map[string]struct{}) + ) + for path := range curSeen { + _, present := origSeen[path] + if !present { + insertExp[path] = struct{}{} + } + } + for path := range origSeen { + _, present := curSeen[path] + if !present { + deleteExp[path] = struct{}{} + } + } + if len(insertExp) != len(tr.tracer.inserts) { + rt[i].err = fmt.Errorf("insert set mismatch") + } + if len(deleteExp) != len(tr.tracer.deletes) { + rt[i].err = fmt.Errorf("delete set mismatch") + } + for insert := range tr.tracer.inserts { + if _, present := insertExp[insert]; !present { + rt[i].err = fmt.Errorf("missing inserted node") + } + } + for del := range tr.tracer.deletes { + if _, present := deleteExp[del]; !present { + rt[i].err = fmt.Errorf("missing deleted node") + } + } + } + // Abort the test on error. 
+ if rt[i].err != nil { + return false + } + } + return true +} + +func TestRandom(t *testing.T) { + if err := quick.Check(runRandTest, nil); err != nil { + if cerr, ok := err.(*quick.CheckError); ok { + t.Fatalf("random test iteration %d failed: %s", cerr.Count, spew.Sdump(cerr.In)) + } + t.Fatal(err) + } +} + +func TestTinyTrie(t *testing.T) { // Create a realistic account trie to hash _, accounts := makeAccounts(5) - edb := rawdb.NewMemoryDatabase() - db := geth_trie.NewDatabase(edb) - origtrie := geth_trie.NewEmpty(db) + trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) type testCase struct { key, account []byte root common.Hash } + cases := []testCase{ { common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000001337"), @@ -92,48 +601,48 @@ func TestTrieTiny(t *testing.T) { common.HexToHash("0608c1d1dc3905fa22204c7a0e43644831c3b6d3def0f274be623a948197e64a"), }, } - for i, tc := range cases { - t.Run(fmt.Sprintf("case %d", i), func(t *testing.T) { - origtrie.Update(tc.key, tc.account) - trie := indexTrie(t, edb, commitTrie(t, db, origtrie)) - if exp, root := tc.root, trie.Hash(); exp != root { - t.Errorf("got %x, exp %x", root, exp) - } - checkValue(t, trie, tc.key) - }) + for i, c := range cases { + trie.Update(c.key, c.account) + root := trie.Hash() + if root != c.root { + t.Errorf("case %d: got %x, exp %x", i, root, c.root) + } + } + checktr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) + it := NewIterator(trie.NodeIterator(nil)) + for it.Next() { + checktr.Update(it.Key, it.Value) + } + if troot, itroot := trie.Hash(), checktr.Hash(); troot != itroot { + t.Fatalf("hash mismatch in opItercheckhash, trie: %x, check: %x", troot, itroot) } } -func TestTrieMedium(t *testing.T) { +func TestCommitAfterHash(t *testing.T) { // Create a realistic account trie to hash addresses, accounts := makeAccounts(1000) - edb := rawdb.NewMemoryDatabase() - db := geth_trie.NewDatabase(edb) - origtrie := geth_trie.NewEmpty(db) - var keys [][]byte + 
trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase())) for i := 0; i < len(addresses); i++ { - key := crypto.Keccak256(addresses[i][:]) - if i%50 == 0 { - keys = append(keys, key) - } - origtrie.Update(key, accounts[i]) + trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i]) } - tr := indexTrie(t, edb, commitTrie(t, db, origtrie)) - - root := tr.Hash() + // Insert the accounts into the trie and hash it + trie.Hash() + trie.Commit(false) + root := trie.Hash() exp := common.HexToHash("72f9d3f3fe1e1dd7b8936442e7642aef76371472d94319900790053c493f3fe6") if exp != root { t.Errorf("got %x, exp %x", root, exp) } - - for _, key := range keys { - checkValue(t, tr, key) + root, _ = trie.Commit(false) + if exp != root { + t.Errorf("got %x, exp %x", root, exp) } } -// Make deterministically random accounts func makeAccounts(size int) (addresses [][20]byte, accounts [][]byte) { + // Make the random benchmark deterministic random := rand.New(rand.NewSource(0)) + // Create a realistic account trie to hash addresses = make([][20]byte, size) for i := 0; i < len(addresses); i++ { data := make([]byte, 20) @@ -149,25 +658,40 @@ func makeAccounts(size int) (addresses [][20]byte, accounts [][]byte) { ) // The big.Rand function is not deterministic with regards to 64 vs 32 bit systems, // and will consume different amount of data from the rand source. 
- // balance = new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil)) + //balance = new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil)) // Therefore, we instead just read via byte buffer numBytes := random.Uint32() % 33 // [0, 32] bytes balanceBytes := make([]byte, numBytes) random.Read(balanceBytes) balance := new(big.Int).SetBytes(balanceBytes) - acct := &types.StateAccount{Nonce: nonce, Balance: balance, Root: root, CodeHash: code} - data, _ := rlp.EncodeToBytes(acct) + data, _ := rlp.EncodeToBytes(&types.StateAccount{Nonce: nonce, Balance: balance, Root: root, CodeHash: code}) accounts[i] = data } return addresses, accounts } -func checkValue(t *testing.T, tr *trie.Trie, key []byte) { - val, err := tr.TryGet(key) - if err != nil { - t.Fatalf("error getting node: %s", err) - } - if len(val) == 0 { - t.Errorf("failed to get value for %x", key) +func getString(trie *Trie, k string) []byte { + return trie.Get([]byte(k)) +} + +func updateString(trie *Trie, k, v string) { + trie.Update([]byte(k), []byte(v)) +} + +func deleteString(trie *Trie, k string) { + trie.Delete([]byte(k)) +} + +func TestDecodeNode(t *testing.T) { + t.Parallel() + + var ( + hash = make([]byte, 20) + elems = make([]byte, 20) + ) + for i := 0; i < 5000000; i++ { + prng.Read(hash) + prng.Read(elems) + decodeNode(hash, elems) } } diff --git a/trie_by_cid/trie/util_test.go b/trie_by_cid/trie/util_test.go index 3272826..a75ac8c 100644 --- a/trie_by_cid/trie/util_test.go +++ b/trie_by_cid/trie/util_test.go @@ -1,43 +1,133 @@ -package trie_test +package trie import ( + "bytes" + "context" "fmt" "math/big" "math/rand" "testing" - "time" - - "github.com/jmoiron/sqlx" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" - geth_state "github.com/ethereum/go-ethereum/core/state" + gethstate "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/ethdb" 
"github.com/ethereum/go-ethereum/rlp" - geth_trie "github.com/ethereum/go-ethereum/trie" + gethtrie "github.com/ethereum/go-ethereum/trie" + "github.com/jmoiron/sqlx" pgipfsethdb "github.com/cerc-io/ipfs-ethdb/v5/postgres/v0" - "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/helper" - "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/state" - "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/trie" "github.com/ethereum/go-ethereum/statediff/indexer/database/sql/postgres" - "github.com/ethereum/go-ethereum/statediff/indexer/ipld" "github.com/ethereum/go-ethereum/statediff/test_helpers" + + "github.com/cerc-io/ipld-eth-statedb/internal" + "github.com/cerc-io/ipld-eth-statedb/trie_by_cid/helper" ) -type kv struct { +var ( + dbConfig, _ = postgres.DefaultConfig.WithEnv() + trieConfig = Config{Cache: 256} +) + +type kvi struct { k []byte v int64 } -type kvMap map[string]*kv +type kvMap map[string]*kvi -type kvs struct { +type kvsi struct { k string v int64 } +// NewAccountTrie is a shortcut to create a trie using the StateTrieCodec (ie. IPLD MEthStateTrie codec). +func NewAccountTrie(id *ID, db NodeReader) (*Trie, error) { + return New(id, db, StateTrieCodec) +} + +// makeTestTrie create a sample test trie to test node-wise reconstruction. 
+func makeTestTrie(t testing.TB) (*Database, *StateTrie, map[string][]byte) { + // Create an empty trie + triedb := NewDatabase(rawdb.NewMemoryDatabase()) + trie, err := NewStateTrie(TrieID(common.Hash{}), triedb, StateTrieCodec) + if err != nil { + t.Fatal(err) + } + + // Fill it with some arbitrary data + content := make(map[string][]byte) + for i := byte(0); i < 255; i++ { + // Map the same data under multiple keys + key, val := common.LeftPadBytes([]byte{1, i}, 32), []byte{i} + content[string(key)] = val + trie.Update(key, val) + + key, val = common.LeftPadBytes([]byte{2, i}, 32), []byte{i} + content[string(key)] = val + trie.Update(key, val) + + // Add some other data to inflate the trie + for j := byte(3); j < 13; j++ { + key, val = common.LeftPadBytes([]byte{j, i}, 32), []byte{j, i} + content[string(key)] = val + trie.Update(key, val) + } + } + root, nodes := trie.Commit(false) + if err := triedb.Update(NewWithNodeSet(nodes)); err != nil { + panic(fmt.Errorf("failed to commit db %v", err)) + } + // Re-create the trie based on the new state + trie, err = NewStateTrie(TrieID(root), triedb, StateTrieCodec) + if err != nil { + t.Fatal(err) + } + return triedb, trie, content +} + +func forHashedNodes(tr *Trie) map[string][]byte { + var ( + it = tr.NodeIterator(nil) + nodes = make(map[string][]byte) + ) + for it.Next(true) { + if it.Hash() == (common.Hash{}) { + continue + } + nodes[string(it.Path())] = common.CopyBytes(it.NodeBlob()) + } + return nodes +} + +func diffTries(trieA, trieB *Trie) (map[string][]byte, map[string][]byte, map[string][]byte) { + var ( + nodesA = forHashedNodes(trieA) + nodesB = forHashedNodes(trieB) + inA = make(map[string][]byte) // hashed nodes in trie a but not b + inB = make(map[string][]byte) // hashed nodes in trie b but not a + both = make(map[string][]byte) // hashed nodes in both tries but different value + ) + for path, blobA := range nodesA { + if blobB, ok := nodesB[path]; ok { + if bytes.Equal(blobA, blobB) { + continue + } + 
both[path] = blobA + continue + } + inA[path] = blobA + } + for path, blobB := range nodesB { + if _, ok := nodesA[path]; ok { + continue + } + inB[path] = blobB + } + return inA, inB, both +} + func packValue(val int64) []byte { acct := &types.StateAccount{ Balance: big.NewInt(val), @@ -51,27 +141,19 @@ func packValue(val int64) []byte { return acct_rlp } -func unpackValue(val []byte) int64 { - var acct types.StateAccount - if err := rlp.DecodeBytes(val, &acct); err != nil { - panic(err) - } - return acct.Balance.Int64() -} - -func updateTrie(tr *geth_trie.Trie, vals []kvs) (kvMap, error) { +func updateTrie(tr *gethtrie.Trie, vals []kvsi) (kvMap, error) { all := kvMap{} for _, val := range vals { - all[string(val.k)] = &kv{[]byte(val.k), val.v} + all[string(val.k)] = &kvi{[]byte(val.k), val.v} tr.Update([]byte(val.k), packValue(val.v)) } return all, nil } -func commitTrie(t testing.TB, db *geth_trie.Database, tr *geth_trie.Trie) common.Hash { +func commitTrie(t testing.TB, db *gethtrie.Database, tr *gethtrie.Trie) common.Hash { t.Helper() root, nodes := tr.Commit(false) - if err := db.Update(geth_trie.NewWithNodeSet(nodes)); err != nil { + if err := db.Update(gethtrie.NewWithNodeSet(nodes)); err != nil { t.Fatal(err) } if err := db.Commit(root, false); err != nil { @@ -80,16 +162,8 @@ func commitTrie(t testing.TB, db *geth_trie.Database, tr *geth_trie.Trie) common return root } -// commit a LevelDB state trie, index to IPLD and return new trie -func indexTrie(t testing.TB, edb ethdb.Database, root common.Hash) *trie.Trie { - t.Helper() - dbConfig.Driver = postgres.PGX - err := helper.IndexChain(dbConfig, geth_state.NewDatabase(edb), common.Hash{}, root) - if err != nil { - t.Fatal(err) - } - - pg_db, err := postgres.ConnectSQLX(ctx, dbConfig) +func makePgIpfsEthDB(t testing.TB) ethdb.Database { + pg_db, err := postgres.ConnectSQLX(context.Background(), dbConfig) if err != nil { t.Fatal(err) } @@ -98,62 +172,48 @@ func indexTrie(t testing.TB, edb ethdb.Database, 
root common.Hash) *trie.Trie { t.Fatal(err) } }) + return pgipfsethdb.NewDatabase(pg_db, internal.MakeCacheConfig(t)) +} - ipfs_db := pgipfsethdb.NewDatabase(pg_db, makeCacheConfig(t)) - sdb_db := state.NewDatabase(ipfs_db) - tr, err := newStateTrie(root, sdb_db.TrieDB()) +// commit a LevelDB state trie, index to IPLD and return new trie +func indexTrie(t testing.TB, edb ethdb.Database, root common.Hash) *Trie { + t.Helper() + dbConfig.Driver = postgres.PGX + err := helper.IndexStateDiff(dbConfig, gethstate.NewDatabase(edb), common.Hash{}, root) + if err != nil { + t.Fatal(err) + } + + ipfs_db := makePgIpfsEthDB(t) + tr, err := New(TrieID(root), NewDatabase(ipfs_db), StateTrieCodec) if err != nil { t.Fatal(err) } return tr } -func newStateTrie(root common.Hash, db *trie.Database) (*trie.Trie, error) { - tr, err := trie.New(common.Hash{}, root, db, ipld.MEthStateTrie) - if err != nil { - return nil, err - } - return tr, nil -} - // generates a random Geth LevelDB trie of n key-value pairs and corresponding value map -func randomGethTrie(n int, db *geth_trie.Database) (*geth_trie.Trie, kvMap) { - trie := geth_trie.NewEmpty(db) - var vals []*kv +func randomGethTrie(n int, db *gethtrie.Database) (*gethtrie.Trie, kvMap) { + trie := gethtrie.NewEmpty(db) + var vals []*kvi for i := byte(0); i < 100; i++ { - e := &kv{common.LeftPadBytes([]byte{i}, 32), int64(i)} - e2 := &kv{common.LeftPadBytes([]byte{i + 10}, 32), int64(i)} + e := &kvi{common.LeftPadBytes([]byte{i}, 32), int64(i)} + e2 := &kvi{common.LeftPadBytes([]byte{i + 10}, 32), int64(i)} vals = append(vals, e, e2) } for i := 0; i < n; i++ { k := randBytes(32) v := rand.Int63() - vals = append(vals, &kv{k, v}) + vals = append(vals, &kvi{k, v}) } all := kvMap{} for _, val := range vals { - all[string(val.k)] = &kv{[]byte(val.k), val.v} + all[string(val.k)] = &kvi{[]byte(val.k), val.v} trie.Update([]byte(val.k), packValue(val.v)) } return trie, all } -// generates a random IPLD-indexed trie -func randomTrie(t 
testing.TB, n int) (*trie.Trie, kvMap) { - edb := rawdb.NewMemoryDatabase() - db := geth_trie.NewDatabase(edb) - orig, vals := randomGethTrie(n, db) - root := commitTrie(t, db, orig) - trie := indexTrie(t, edb, root) - return trie, vals -} - -func randBytes(n int) []byte { - r := make([]byte, n) - rand.Read(r) - return r -} - // TearDownDB is used to tear down the watcher dbs after tests func TearDownDB(db *sqlx.DB) error { tx, err := db.Beginx() @@ -179,12 +239,3 @@ func TearDownDB(db *sqlx.DB) error { } return tx.Commit() } - -// returns a cache config with unique name (groupcache names are global) -func makeCacheConfig(t testing.TB) pgipfsethdb.CacheConfig { - return pgipfsethdb.CacheConfig{ - Name: t.Name(), - Size: 3000000, // 3MB - ExpiryDuration: time.Hour, - } -}