core/state: access trie through Database interface, track errors (#14589)

With this commit, core/state's access to the underlying key/value database is
mediated through an interface. Database errors are tracked in StateDB and
returned by CommitTo or the new Error method.

Motivation for this change: We can remove the light client's duplicated copy of
core/state. The light client now supports node iteration, so tracing and storage
enumeration can work with the light client (not implemented in this commit).
This commit is contained in:
Felix Lange 2017-06-27 15:57:06 +02:00 committed by GitHub
parent bb366271fe
commit 9e5f03b6c4
49 changed files with 810 additions and 1664 deletions

View File

@ -90,7 +90,7 @@ func (b *SimulatedBackend) Rollback() {
func (b *SimulatedBackend) rollback() { func (b *SimulatedBackend) rollback() {
blocks, _ := core.GenerateChain(b.config, b.blockchain.CurrentBlock(), b.database, 1, func(int, *core.BlockGen) {}) blocks, _ := core.GenerateChain(b.config, b.blockchain.CurrentBlock(), b.database, 1, func(int, *core.BlockGen) {})
b.pendingBlock = blocks[0] b.pendingBlock = blocks[0]
b.pendingState, _ = state.New(b.pendingBlock.Root(), b.database) b.pendingState, _ = state.New(b.pendingBlock.Root(), state.NewDatabase(b.database))
} }
// CodeAt returns the code associated with a certain account in the blockchain. // CodeAt returns the code associated with a certain account in the blockchain.
@ -279,7 +279,7 @@ func (b *SimulatedBackend) SendTransaction(ctx context.Context, tx *types.Transa
block.AddTx(tx) block.AddTx(tx)
}) })
b.pendingBlock = blocks[0] b.pendingBlock = blocks[0]
b.pendingState, _ = state.New(b.pendingBlock.Root(), b.database) b.pendingState, _ = state.New(b.pendingBlock.Root(), state.NewDatabase(b.database))
return nil return nil
} }

View File

@ -98,8 +98,8 @@ func runCmd(ctx *cli.Context) error {
_, statedb = gen.ToBlock() _, statedb = gen.ToBlock()
chainConfig = gen.Config chainConfig = gen.Config
} else { } else {
var db, _ = ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ = state.New(common.Hash{}, db) statedb, _ = state.New(common.Hash{}, state.NewDatabase(db))
} }
if ctx.GlobalString(SenderFlag.Name) != "" { if ctx.GlobalString(SenderFlag.Name) != "" {
sender = common.HexToAddress(ctx.GlobalString(SenderFlag.Name)) sender = common.HexToAddress(ctx.GlobalString(SenderFlag.Name))
@ -188,7 +188,7 @@ func runCmd(ctx *cli.Context) error {
execTime := time.Since(tstart) execTime := time.Since(tstart)
if ctx.GlobalBool(DumpFlag.Name) { if ctx.GlobalBool(DumpFlag.Name) {
statedb.Commit(true) statedb.IntermediateRoot(true)
fmt.Println(string(statedb.Dump())) fmt.Println(string(statedb.Dump()))
} }

View File

@ -312,7 +312,7 @@ func dump(ctx *cli.Context) error {
fmt.Println("{}") fmt.Println("{}")
utils.Fatalf("block not found") utils.Fatalf("block not found")
} else { } else {
state, err := state.New(block.Root(), chainDb) state, err := state.New(block.Root(), state.NewDatabase(chainDb))
if err != nil { if err != nil {
utils.Fatalf("could not create new state: %v", err) utils.Fatalf("could not create new state: %v", err)
} }

View File

@ -52,16 +52,10 @@ func NewBlockValidator(config *params.ChainConfig, blockchain *BlockChain, engin
// validated at this point. // validated at this point.
func (v *BlockValidator) ValidateBody(block *types.Block) error { func (v *BlockValidator) ValidateBody(block *types.Block) error {
// Check whether the block's known, and if not, that it's linkable // Check whether the block's known, and if not, that it's linkable
if v.bc.HasBlock(block.Hash()) { if v.bc.HasBlockAndState(block.Hash()) {
if _, err := state.New(block.Root(), v.bc.chainDb); err == nil {
return ErrKnownBlock return ErrKnownBlock
} }
} if !v.bc.HasBlockAndState(block.ParentHash()) {
parent := v.bc.GetBlock(block.ParentHash(), block.NumberU64()-1)
if parent == nil {
return consensus.ErrUnknownAncestor
}
if _, err := state.New(parent.Root(), v.bc.chainDb); err != nil {
return consensus.ErrUnknownAncestor return consensus.ErrUnknownAncestor
} }
// Header validity is known at this point, check the uncles and transactions // Header validity is known at this point, check the uncles and transactions

View File

@ -92,7 +92,7 @@ type BlockChain struct {
currentBlock *types.Block // Current head of the block chain currentBlock *types.Block // Current head of the block chain
currentFastBlock *types.Block // Current head of the fast-sync chain (may be above the block chain!) currentFastBlock *types.Block // Current head of the fast-sync chain (may be above the block chain!)
stateCache *state.StateDB // State database to reuse between imports (contains state cache) stateCache state.Database // State database to reuse between imports (contains state cache)
bodyCache *lru.Cache // Cache for the most recent block bodies bodyCache *lru.Cache // Cache for the most recent block bodies
bodyRLPCache *lru.Cache // Cache for the most recent block bodies in RLP encoded format bodyRLPCache *lru.Cache // Cache for the most recent block bodies in RLP encoded format
blockCache *lru.Cache // Cache for the most recent entire blocks blockCache *lru.Cache // Cache for the most recent entire blocks
@ -125,6 +125,7 @@ func NewBlockChain(chainDb ethdb.Database, config *params.ChainConfig, engine co
bc := &BlockChain{ bc := &BlockChain{
config: config, config: config,
chainDb: chainDb, chainDb: chainDb,
stateCache: state.NewDatabase(chainDb),
eventMux: mux, eventMux: mux,
quit: make(chan struct{}), quit: make(chan struct{}),
bodyCache: bodyCache, bodyCache: bodyCache,
@ -190,7 +191,7 @@ func (bc *BlockChain) loadLastState() error {
return bc.Reset() return bc.Reset()
} }
// Make sure the state associated with the block is available // Make sure the state associated with the block is available
if _, err := state.New(currentBlock.Root(), bc.chainDb); err != nil { if _, err := state.New(currentBlock.Root(), bc.stateCache); err != nil {
// Dangling block without a state associated, init from scratch // Dangling block without a state associated, init from scratch
log.Warn("Head state missing, resetting chain", "number", currentBlock.Number(), "hash", currentBlock.Hash()) log.Warn("Head state missing, resetting chain", "number", currentBlock.Number(), "hash", currentBlock.Hash())
return bc.Reset() return bc.Reset()
@ -214,12 +215,6 @@ func (bc *BlockChain) loadLastState() error {
bc.currentFastBlock = block bc.currentFastBlock = block
} }
} }
// Initialize a statedb cache to ensure singleton account bloom filter generation
statedb, err := state.New(bc.currentBlock.Root(), bc.chainDb)
if err != nil {
return err
}
bc.stateCache = statedb
// Issue a status log for the user // Issue a status log for the user
headerTd := bc.GetTd(currentHeader.Hash(), currentHeader.Number.Uint64()) headerTd := bc.GetTd(currentHeader.Hash(), currentHeader.Number.Uint64())
@ -261,7 +256,7 @@ func (bc *BlockChain) SetHead(head uint64) error {
bc.currentBlock = bc.GetBlock(currentHeader.Hash(), currentHeader.Number.Uint64()) bc.currentBlock = bc.GetBlock(currentHeader.Hash(), currentHeader.Number.Uint64())
} }
if bc.currentBlock != nil { if bc.currentBlock != nil {
if _, err := state.New(bc.currentBlock.Root(), bc.chainDb); err != nil { if _, err := state.New(bc.currentBlock.Root(), bc.stateCache); err != nil {
// Rewound state missing, rolled back to before pivot, reset to genesis // Rewound state missing, rolled back to before pivot, reset to genesis
bc.currentBlock = nil bc.currentBlock = nil
} }
@ -384,7 +379,7 @@ func (bc *BlockChain) State() (*state.StateDB, error) {
// StateAt returns a new mutable state based on a particular point in time. // StateAt returns a new mutable state based on a particular point in time.
func (bc *BlockChain) StateAt(root common.Hash) (*state.StateDB, error) { func (bc *BlockChain) StateAt(root common.Hash) (*state.StateDB, error) {
return bc.stateCache.New(root) return state.New(root, bc.stateCache)
} }
// Reset purges the entire blockchain, restoring it to its genesis state. // Reset purges the entire blockchain, restoring it to its genesis state.
@ -531,7 +526,7 @@ func (bc *BlockChain) HasBlockAndState(hash common.Hash) bool {
return false return false
} }
// Ensure the associated state is also present // Ensure the associated state is also present
_, err := state.New(block.Root(), bc.chainDb) _, err := bc.stateCache.OpenTrie(block.Root())
return err == nil return err == nil
} }
@ -959,31 +954,30 @@ func (bc *BlockChain) InsertChain(chain types.Blocks) (int, error) {
} }
// Create a new statedb using the parent block and report an // Create a new statedb using the parent block and report an
// error if it fails. // error if it fails.
switch { var parent *types.Block
case i == 0: if i == 0 {
err = bc.stateCache.Reset(bc.GetBlock(block.ParentHash(), block.NumberU64()-1).Root()) parent = bc.GetBlock(block.ParentHash(), block.NumberU64()-1)
default: } else {
err = bc.stateCache.Reset(chain[i-1].Root()) parent = chain[i-1]
} }
state, err := state.New(parent.Root(), bc.stateCache)
if err != nil { if err != nil {
bc.reportBlock(block, nil, err)
return i, err return i, err
} }
// Process block using the parent state as reference point. // Process block using the parent state as reference point.
receipts, logs, usedGas, err := bc.processor.Process(block, bc.stateCache, bc.vmConfig) receipts, logs, usedGas, err := bc.processor.Process(block, state, bc.vmConfig)
if err != nil { if err != nil {
bc.reportBlock(block, receipts, err) bc.reportBlock(block, receipts, err)
return i, err return i, err
} }
// Validate the state using the default validator // Validate the state using the default validator
err = bc.Validator().ValidateState(block, bc.GetBlock(block.ParentHash(), block.NumberU64()-1), bc.stateCache, receipts, usedGas) err = bc.Validator().ValidateState(block, parent, state, receipts, usedGas)
if err != nil { if err != nil {
bc.reportBlock(block, receipts, err) bc.reportBlock(block, receipts, err)
return i, err return i, err
} }
// Write state changes to database // Write state changes to database
_, err = bc.stateCache.Commit(bc.config.IsEIP158(block.Number())) if _, err = state.CommitTo(bc.chainDb, bc.config.IsEIP158(block.Number())); err != nil {
if err != nil {
return i, err return i, err
} }
@ -1021,7 +1015,7 @@ func (bc *BlockChain) InsertChain(chain types.Blocks) (int, error) {
return i, err return i, err
} }
// Write hash preimages // Write hash preimages
if err := WritePreimages(bc.chainDb, block.NumberU64(), bc.stateCache.Preimages()); err != nil { if err := WritePreimages(bc.chainDb, block.NumberU64(), state.Preimages()); err != nil {
return i, err return i, err
} }
case SideStatTy: case SideStatTy:

View File

@ -131,7 +131,7 @@ func testBlockChainImport(chain types.Blocks, blockchain *BlockChain) error {
} }
return err return err
} }
statedb, err := state.New(blockchain.GetBlockByHash(block.ParentHash()).Root(), blockchain.chainDb) statedb, err := state.New(blockchain.GetBlockByHash(block.ParentHash()).Root(), blockchain.stateCache)
if err != nil { if err != nil {
return err return err
} }
@ -148,7 +148,7 @@ func testBlockChainImport(chain types.Blocks, blockchain *BlockChain) error {
blockchain.mu.Lock() blockchain.mu.Lock()
WriteTd(blockchain.chainDb, block.Hash(), block.NumberU64(), new(big.Int).Add(block.Difficulty(), blockchain.GetTdByHash(block.ParentHash()))) WriteTd(blockchain.chainDb, block.Hash(), block.NumberU64(), new(big.Int).Add(block.Difficulty(), blockchain.GetTdByHash(block.ParentHash())))
WriteBlock(blockchain.chainDb, block) WriteBlock(blockchain.chainDb, block)
statedb.Commit(false) statedb.CommitTo(blockchain.chainDb, false)
blockchain.mu.Unlock() blockchain.mu.Unlock()
} }
return nil return nil
@ -1131,7 +1131,7 @@ func TestEIP161AccountRemoval(t *testing.T) {
if _, err := blockchain.InsertChain(types.Blocks{blocks[0]}); err != nil { if _, err := blockchain.InsertChain(types.Blocks{blocks[0]}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !blockchain.stateCache.Exist(theAddr) { if st, _ := blockchain.State(); !st.Exist(theAddr) {
t.Error("expected account to exist") t.Error("expected account to exist")
} }
@ -1139,7 +1139,7 @@ func TestEIP161AccountRemoval(t *testing.T) {
if _, err := blockchain.InsertChain(types.Blocks{blocks[1]}); err != nil { if _, err := blockchain.InsertChain(types.Blocks{blocks[1]}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if blockchain.stateCache.Exist(theAddr) { if st, _ := blockchain.State(); st.Exist(theAddr) {
t.Error("account should not exist") t.Error("account should not exist")
} }
@ -1147,7 +1147,7 @@ func TestEIP161AccountRemoval(t *testing.T) {
if _, err := blockchain.InsertChain(types.Blocks{blocks[2]}); err != nil { if _, err := blockchain.InsertChain(types.Blocks{blocks[2]}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if blockchain.stateCache.Exist(theAddr) { if st, _ := blockchain.State(); st.Exist(theAddr) {
t.Error("account should not exist") t.Error("account should not exist")
} }
} }

View File

@ -181,7 +181,7 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, db ethdb.Dat
gen(i, b) gen(i, b)
} }
ethash.AccumulateRewards(statedb, h, b.uncles) ethash.AccumulateRewards(statedb, h, b.uncles)
root, err := statedb.Commit(config.IsEIP158(h.Number)) root, err := statedb.CommitTo(db, config.IsEIP158(h.Number))
if err != nil { if err != nil {
panic(fmt.Sprintf("state write error: %v", err)) panic(fmt.Sprintf("state write error: %v", err))
} }
@ -189,7 +189,7 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, db ethdb.Dat
return types.NewBlock(h, b.txs, b.uncles, b.receipts), b.receipts return types.NewBlock(h, b.txs, b.uncles, b.receipts), b.receipts
} }
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
statedb, err := state.New(parent.Root(), db) statedb, err := state.New(parent.Root(), state.NewDatabase(db))
if err != nil { if err != nil {
panic(err) panic(err)
} }

View File

@ -176,7 +176,7 @@ func (g *Genesis) configOrDefault(ghash common.Hash) *params.ChainConfig {
// ToBlock creates the block and state of a genesis specification. // ToBlock creates the block and state of a genesis specification.
func (g *Genesis) ToBlock() (*types.Block, *state.StateDB) { func (g *Genesis) ToBlock() (*types.Block, *state.StateDB) {
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db) statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
for addr, account := range g.Alloc { for addr, account := range g.Alloc {
statedb.AddBalance(addr, account.Balance) statedb.AddBalance(addr, account.Balance)
statedb.SetCode(addr, account.Code) statedb.SetCode(addr, account.Code)

154
core/state/database.go Normal file
View File

@ -0,0 +1,154 @@
// Copyright 2017 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package state
import (
"fmt"
"sync"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/trie"
lru "github.com/hashicorp/golang-lru"
)
// Trie cache generation limit after which to evic trie nodes from memory.
var MaxTrieCacheGen = uint16(120)
const (
// Number of past tries to keep. This value is chosen such that
// reasonable chain reorg depths will hit an existing trie.
maxPastTries = 12
// Number of codehash->size associations to keep.
codeSizeCacheSize = 100000
)
// Database wraps access to tries and contract code.
type Database interface {
// Accessing tries:
// OpenTrie opens the main account trie.
// OpenStorageTrie opens the storage trie of an account.
OpenTrie(root common.Hash) (Trie, error)
OpenStorageTrie(addrHash, root common.Hash) (Trie, error)
// Accessing contract code:
ContractCode(addrHash, codeHash common.Hash) ([]byte, error)
ContractCodeSize(addrHash, codeHash common.Hash) (int, error)
// CopyTrie returns an independent copy of the given trie.
CopyTrie(Trie) Trie
}
// Trie is a Ethereum Merkle Trie.
type Trie interface {
TryGet(key []byte) ([]byte, error)
TryUpdate(key, value []byte) error
TryDelete(key []byte) error
CommitTo(trie.DatabaseWriter) (common.Hash, error)
Hash() common.Hash
NodeIterator(startKey []byte) trie.NodeIterator
GetKey([]byte) []byte // TODO(fjl): remove this when SecureTrie is removed
}
// NewDatabase creates a backing store for state. The returned database is safe for
// concurrent use and retains cached trie nodes in memory.
func NewDatabase(db ethdb.Database) Database {
csc, _ := lru.New(codeSizeCacheSize)
return &cachingDB{db: db, codeSizeCache: csc}
}
type cachingDB struct {
db ethdb.Database
mu sync.Mutex
pastTries []*trie.SecureTrie
codeSizeCache *lru.Cache
}
func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) {
db.mu.Lock()
defer db.mu.Unlock()
for i := len(db.pastTries) - 1; i >= 0; i-- {
if db.pastTries[i].Hash() == root {
return cachedTrie{db.pastTries[i].Copy(), db}, nil
}
}
tr, err := trie.NewSecure(root, db.db, MaxTrieCacheGen)
if err != nil {
return nil, err
}
return cachedTrie{tr, db}, nil
}
func (db *cachingDB) pushTrie(t *trie.SecureTrie) {
db.mu.Lock()
defer db.mu.Unlock()
if len(db.pastTries) >= maxPastTries {
copy(db.pastTries, db.pastTries[1:])
db.pastTries[len(db.pastTries)-1] = t
} else {
db.pastTries = append(db.pastTries, t)
}
}
func (db *cachingDB) OpenStorageTrie(addrHash, root common.Hash) (Trie, error) {
return trie.NewSecure(root, db.db, 0)
}
func (db *cachingDB) CopyTrie(t Trie) Trie {
switch t := t.(type) {
case cachedTrie:
return cachedTrie{t.SecureTrie.Copy(), db}
case *trie.SecureTrie:
return t.Copy()
default:
panic(fmt.Errorf("unknown trie type %T", t))
}
}
func (db *cachingDB) ContractCode(addrHash, codeHash common.Hash) ([]byte, error) {
code, err := db.db.Get(codeHash[:])
if err == nil {
db.codeSizeCache.Add(codeHash, len(code))
}
return code, err
}
func (db *cachingDB) ContractCodeSize(addrHash, codeHash common.Hash) (int, error) {
if cached, ok := db.codeSizeCache.Get(codeHash); ok {
return cached.(int), nil
}
code, err := db.ContractCode(addrHash, codeHash)
if err == nil {
db.codeSizeCache.Add(codeHash, len(code))
}
return len(code), err
}
// cachedTrie inserts its trie into a cachingDB on commit.
type cachedTrie struct {
*trie.SecureTrie
db *cachingDB
}
func (m cachedTrie) CommitTo(dbw trie.DatabaseWriter) (common.Hash, error) {
root, err := m.SecureTrie.CommitTo(dbw)
if err == nil {
m.db.pushTrie(m.SecureTrie)
}
return root, err
}

View File

@ -41,7 +41,7 @@ type Dump struct {
func (self *StateDB) RawDump() Dump { func (self *StateDB) RawDump() Dump {
dump := Dump{ dump := Dump{
Root: common.Bytes2Hex(self.trie.Root()), Root: fmt.Sprintf("%x", self.trie.Hash()),
Accounts: make(map[string]DumpAccount), Accounts: make(map[string]DumpAccount),
} }

View File

@ -19,7 +19,6 @@ package state
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"math/big"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
@ -105,16 +104,11 @@ func (it *NodeIterator) step() error {
return nil return nil
} }
// Otherwise we've reached an account node, initiate data iteration // Otherwise we've reached an account node, initiate data iteration
var account struct { var account Account
Nonce uint64
Balance *big.Int
Root common.Hash
CodeHash []byte
}
if err := rlp.Decode(bytes.NewReader(it.stateIt.LeafBlob()), &account); err != nil { if err := rlp.Decode(bytes.NewReader(it.stateIt.LeafBlob()), &account); err != nil {
return err return err
} }
dataTrie, err := trie.New(account.Root, it.state.db) dataTrie, err := it.state.db.OpenStorageTrie(common.BytesToHash(it.stateIt.LeafKey()), account.Root)
if err != nil { if err != nil {
return err return err
} }
@ -124,7 +118,8 @@ func (it *NodeIterator) step() error {
} }
if !bytes.Equal(account.CodeHash, emptyCodeHash) { if !bytes.Equal(account.CodeHash, emptyCodeHash) {
it.codeHash = common.BytesToHash(account.CodeHash) it.codeHash = common.BytesToHash(account.CodeHash)
it.code, err = it.state.db.Get(account.CodeHash) addrHash := common.BytesToHash(it.stateIt.LeafKey())
it.code, err = it.state.db.ContractCode(addrHash, common.BytesToHash(account.CodeHash))
if err != nil { if err != nil {
return fmt.Errorf("code %x: %v", account.CodeHash, err) return fmt.Errorf("code %x: %v", account.CodeHash, err)
} }

View File

@ -21,13 +21,12 @@ import (
"testing" "testing"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/ethdb"
) )
// Tests that the node iterator indeed walks over the entire database contents. // Tests that the node iterator indeed walks over the entire database contents.
func TestNodeIteratorCoverage(t *testing.T) { func TestNodeIteratorCoverage(t *testing.T) {
// Create some arbitrary test state to iterate // Create some arbitrary test state to iterate
db, root, _ := makeTestState() db, mem, root, _ := makeTestState()
state, err := New(root, db) state, err := New(root, db)
if err != nil { if err != nil {
@ -40,13 +39,14 @@ func TestNodeIteratorCoverage(t *testing.T) {
hashes[it.Hash] = struct{}{} hashes[it.Hash] = struct{}{}
} }
} }
// Cross check the hashes and the database itself // Cross check the hashes and the database itself
for hash := range hashes { for hash := range hashes {
if _, err := db.Get(hash.Bytes()); err != nil { if _, err := mem.Get(hash.Bytes()); err != nil {
t.Errorf("failed to retrieve reported node %x: %v", hash, err) t.Errorf("failed to retrieve reported node %x: %v", hash, err)
} }
} }
for _, key := range db.(*ethdb.MemDatabase).Keys() { for _, key := range mem.Keys() {
if bytes.HasPrefix(key, []byte("secure-key-")) { if bytes.HasPrefix(key, []byte("secure-key-")) {
continue continue
} }

View File

@ -27,7 +27,7 @@ var addr = common.BytesToAddress([]byte("test"))
func create() (*ManagedState, *account) { func create() (*ManagedState, *account) {
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := New(common.Hash{}, db) statedb, _ := New(common.Hash{}, NewDatabase(db))
ms := ManageState(statedb) ms := ManageState(statedb)
ms.StateDB.SetNonce(addr, 100) ms.StateDB.SetNonce(addr, 100)
ms.accounts[addr] = newAccount(ms.StateDB.getStateObject(addr)) ms.accounts[addr] = newAccount(ms.StateDB.getStateObject(addr))

View File

@ -62,7 +62,8 @@ func (self Storage) Copy() Storage {
// Account values can be accessed and modified through the object. // Account values can be accessed and modified through the object.
// Finally, call CommitTrie to write the modified storage trie into a database. // Finally, call CommitTrie to write the modified storage trie into a database.
type stateObject struct { type stateObject struct {
address common.Address // Ethereum address of this account address common.Address
addrHash common.Hash // hash of ethereum address of the account
data Account data Account
db *StateDB db *StateDB
@ -74,7 +75,7 @@ type stateObject struct {
dbErr error dbErr error
// Write caches. // Write caches.
trie *trie.SecureTrie // storage trie, which becomes non-nil on first access trie Trie // storage trie, which becomes non-nil on first access
code Code // contract bytecode, which gets set when code is loaded code Code // contract bytecode, which gets set when code is loaded
cachedStorage Storage // Storage entry cache to avoid duplicate reads cachedStorage Storage // Storage entry cache to avoid duplicate reads
@ -112,7 +113,15 @@ func newObject(db *StateDB, address common.Address, data Account, onDirty func(a
if data.CodeHash == nil { if data.CodeHash == nil {
data.CodeHash = emptyCodeHash data.CodeHash = emptyCodeHash
} }
return &stateObject{db: db, address: address, data: data, cachedStorage: make(Storage), dirtyStorage: make(Storage), onDirty: onDirty} return &stateObject{
db: db,
address: address,
addrHash: crypto.Keccak256Hash(address[:]),
data: data,
cachedStorage: make(Storage),
dirtyStorage: make(Storage),
onDirty: onDirty,
}
} }
// EncodeRLP implements rlp.Encoder. // EncodeRLP implements rlp.Encoder.
@ -148,12 +157,12 @@ func (c *stateObject) touch() {
c.touched = true c.touched = true
} }
func (c *stateObject) getTrie(db trie.Database) *trie.SecureTrie { func (c *stateObject) getTrie(db Database) Trie {
if c.trie == nil { if c.trie == nil {
var err error var err error
c.trie, err = trie.NewSecure(c.data.Root, db, 0) c.trie, err = db.OpenStorageTrie(c.addrHash, c.data.Root)
if err != nil { if err != nil {
c.trie, _ = trie.NewSecure(common.Hash{}, db, 0) c.trie, _ = db.OpenStorageTrie(c.addrHash, common.Hash{})
c.setError(fmt.Errorf("can't create storage trie: %v", err)) c.setError(fmt.Errorf("can't create storage trie: %v", err))
} }
} }
@ -161,13 +170,18 @@ func (c *stateObject) getTrie(db trie.Database) *trie.SecureTrie {
} }
// GetState returns a value in account storage. // GetState returns a value in account storage.
func (self *stateObject) GetState(db trie.Database, key common.Hash) common.Hash { func (self *stateObject) GetState(db Database, key common.Hash) common.Hash {
value, exists := self.cachedStorage[key] value, exists := self.cachedStorage[key]
if exists { if exists {
return value return value
} }
// Load from DB in case it is missing. // Load from DB in case it is missing.
if enc := self.getTrie(db).Get(key[:]); len(enc) > 0 { enc, err := self.getTrie(db).TryGet(key[:])
if err != nil {
self.setError(err)
return common.Hash{}
}
if len(enc) > 0 {
_, content, _, err := rlp.Split(enc) _, content, _, err := rlp.Split(enc)
if err != nil { if err != nil {
self.setError(err) self.setError(err)
@ -181,7 +195,7 @@ func (self *stateObject) GetState(db trie.Database, key common.Hash) common.Hash
} }
// SetState updates a value in account storage. // SetState updates a value in account storage.
func (self *stateObject) SetState(db trie.Database, key, value common.Hash) { func (self *stateObject) SetState(db Database, key, value common.Hash) {
self.db.journal = append(self.db.journal, storageChange{ self.db.journal = append(self.db.journal, storageChange{
account: &self.address, account: &self.address,
key: key, key: key,
@ -201,30 +215,30 @@ func (self *stateObject) setState(key, value common.Hash) {
} }
// updateTrie writes cached storage modifications into the object's storage trie. // updateTrie writes cached storage modifications into the object's storage trie.
func (self *stateObject) updateTrie(db trie.Database) *trie.SecureTrie { func (self *stateObject) updateTrie(db Database) Trie {
tr := self.getTrie(db) tr := self.getTrie(db)
for key, value := range self.dirtyStorage { for key, value := range self.dirtyStorage {
delete(self.dirtyStorage, key) delete(self.dirtyStorage, key)
if (value == common.Hash{}) { if (value == common.Hash{}) {
tr.Delete(key[:]) self.setError(tr.TryDelete(key[:]))
continue continue
} }
// Encoding []byte cannot fail, ok to ignore the error. // Encoding []byte cannot fail, ok to ignore the error.
v, _ := rlp.EncodeToBytes(bytes.TrimLeft(value[:], "\x00")) v, _ := rlp.EncodeToBytes(bytes.TrimLeft(value[:], "\x00"))
tr.Update(key[:], v) self.setError(tr.TryUpdate(key[:], v))
} }
return tr return tr
} }
// UpdateRoot sets the trie root to the current root hash of // UpdateRoot sets the trie root to the current root hash of
func (self *stateObject) updateRoot(db trie.Database) { func (self *stateObject) updateRoot(db Database) {
self.updateTrie(db) self.updateTrie(db)
self.data.Root = self.trie.Hash() self.data.Root = self.trie.Hash()
} }
// CommitTrie the storage trie of the object to dwb. // CommitTrie the storage trie of the object to dwb.
// This updates the trie root. // This updates the trie root.
func (self *stateObject) CommitTrie(db trie.Database, dbw trie.DatabaseWriter) error { func (self *stateObject) CommitTrie(db Database, dbw trie.DatabaseWriter) error {
self.updateTrie(db) self.updateTrie(db)
if self.dbErr != nil { if self.dbErr != nil {
return self.dbErr return self.dbErr
@ -282,9 +296,7 @@ func (c *stateObject) ReturnGas(gas *big.Int) {}
func (self *stateObject) deepCopy(db *StateDB, onDirty func(addr common.Address)) *stateObject { func (self *stateObject) deepCopy(db *StateDB, onDirty func(addr common.Address)) *stateObject {
stateObject := newObject(db, self.address, self.data, onDirty) stateObject := newObject(db, self.address, self.data, onDirty)
if self.trie != nil { if self.trie != nil {
// A shallow copy makes the two tries independent. stateObject.trie = db.db.CopyTrie(self.trie)
cpy := *self.trie
stateObject.trie = &cpy
} }
stateObject.code = self.code stateObject.code = self.code
stateObject.dirtyStorage = self.dirtyStorage.Copy() stateObject.dirtyStorage = self.dirtyStorage.Copy()
@ -305,14 +317,14 @@ func (c *stateObject) Address() common.Address {
} }
// Code returns the contract code associated with this object, if any. // Code returns the contract code associated with this object, if any.
func (self *stateObject) Code(db trie.Database) []byte { func (self *stateObject) Code(db Database) []byte {
if self.code != nil { if self.code != nil {
return self.code return self.code
} }
if bytes.Equal(self.CodeHash(), emptyCodeHash) { if bytes.Equal(self.CodeHash(), emptyCodeHash) {
return nil return nil
} }
code, err := db.Get(self.CodeHash()) code, err := db.ContractCode(self.addrHash, common.BytesToHash(self.CodeHash()))
if err != nil { if err != nil {
self.setError(fmt.Errorf("can't load code hash %x: %v", self.CodeHash(), err)) self.setError(fmt.Errorf("can't load code hash %x: %v", self.CodeHash(), err))
} }

View File

@ -21,14 +21,14 @@ import (
"math/big" "math/big"
"testing" "testing"
checker "gopkg.in/check.v1"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
checker "gopkg.in/check.v1"
) )
type StateSuite struct { type StateSuite struct {
db *ethdb.MemDatabase
state *StateDB state *StateDB
} }
@ -48,7 +48,7 @@ func (s *StateSuite) TestDump(c *checker.C) {
// write some of them to the trie // write some of them to the trie
s.state.updateStateObject(obj1) s.state.updateStateObject(obj1)
s.state.updateStateObject(obj2) s.state.updateStateObject(obj2)
s.state.Commit(false) s.state.CommitTo(s.db, false)
// check that dump contains the state objects that are in trie // check that dump contains the state objects that are in trie
got := string(s.state.Dump()) got := string(s.state.Dump())
@ -87,23 +87,20 @@ func (s *StateSuite) TestDump(c *checker.C) {
} }
func (s *StateSuite) SetUpTest(c *checker.C) { func (s *StateSuite) SetUpTest(c *checker.C) {
db, _ := ethdb.NewMemDatabase() s.db, _ = ethdb.NewMemDatabase()
s.state, _ = New(common.Hash{}, db) s.state, _ = New(common.Hash{}, NewDatabase(s.db))
} }
func TestNull(t *testing.T) { func (s *StateSuite) TestNull(c *checker.C) {
db, _ := ethdb.NewMemDatabase()
state, _ := New(common.Hash{}, db)
address := common.HexToAddress("0x823140710bf13990e4500136726d8b55") address := common.HexToAddress("0x823140710bf13990e4500136726d8b55")
state.CreateAccount(address) s.state.CreateAccount(address)
//value := common.FromHex("0x823140710bf13990e4500136726d8b55") //value := common.FromHex("0x823140710bf13990e4500136726d8b55")
var value common.Hash var value common.Hash
state.SetState(address, common.Hash{}, value) s.state.SetState(address, common.Hash{}, value)
state.Commit(false) s.state.CommitTo(s.db, false)
value = state.GetState(address, common.Hash{}) value = s.state.GetState(address, common.Hash{})
if !common.EmptyHash(value) { if !common.EmptyHash(value) {
t.Errorf("expected empty hash. got %x", value) c.Errorf("expected empty hash. got %x", value)
} }
} }
@ -129,17 +126,15 @@ func (s *StateSuite) TestSnapshot(c *checker.C) {
c.Assert(data1, checker.DeepEquals, res) c.Assert(data1, checker.DeepEquals, res)
} }
func TestSnapshotEmpty(t *testing.T) { func (s *StateSuite) TestSnapshotEmpty(c *checker.C) {
db, _ := ethdb.NewMemDatabase() s.state.RevertToSnapshot(s.state.Snapshot())
state, _ := New(common.Hash{}, db)
state.RevertToSnapshot(state.Snapshot())
} }
// use testing instead of checker because checker does not support // use testing instead of checker because checker does not support
// printing/logging in tests (-check.vv does not work) // printing/logging in tests (-check.vv does not work)
func TestSnapshot2(t *testing.T) { func TestSnapshot2(t *testing.T) {
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
state, _ := New(common.Hash{}, db) state, _ := New(common.Hash{}, NewDatabase(db))
stateobjaddr0 := toAddr([]byte("so0")) stateobjaddr0 := toAddr([]byte("so0"))
stateobjaddr1 := toAddr([]byte("so1")) stateobjaddr1 := toAddr([]byte("so1"))
@ -160,7 +155,7 @@ func TestSnapshot2(t *testing.T) {
so0.deleted = false so0.deleted = false
state.setStateObject(so0) state.setStateObject(so0)
root, _ := state.Commit(false) root, _ := state.CommitTo(db, false)
state.Reset(root) state.Reset(root)
// and one with deleted == true // and one with deleted == true
@ -182,8 +177,8 @@ func TestSnapshot2(t *testing.T) {
so0Restored := state.getStateObject(stateobjaddr0) so0Restored := state.getStateObject(stateobjaddr0)
// Update lazily-loaded values before comparing. // Update lazily-loaded values before comparing.
so0Restored.GetState(db, storageaddr) so0Restored.GetState(state.db, storageaddr)
so0Restored.Code(db) so0Restored.Code(state.db)
// non-deleted is equal (restored) // non-deleted is equal (restored)
compareStateObjects(so0Restored, so0, t) compareStateObjects(so0Restored, so0, t)

View File

@ -26,23 +26,9 @@ import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie"
lru "github.com/hashicorp/golang-lru"
)
// Trie cache generation limit after which to evic trie nodes from memory.
var MaxTrieCacheGen = uint16(120)
const (
// Number of past tries to keep. This value is chosen such that
// reasonable chain reorg depths will hit an existing trie.
maxPastTries = 12
// Number of codehash->size associations to keep.
codeSizeCacheSize = 100000
) )
type revision struct { type revision struct {
@ -56,16 +42,21 @@ type revision struct {
// * Contracts // * Contracts
// * Accounts // * Accounts
type StateDB struct { type StateDB struct {
db ethdb.Database db Database
trie *trie.SecureTrie trie Trie
pastTries []*trie.SecureTrie
codeSizeCache *lru.Cache
// This map holds 'live' objects, which will get modified while processing a state transition. // This map holds 'live' objects, which will get modified while processing a state transition.
stateObjects map[common.Address]*stateObject stateObjects map[common.Address]*stateObject
stateObjectsDirty map[common.Address]struct{} stateObjectsDirty map[common.Address]struct{}
stateObjectsDestructed map[common.Address]struct{} stateObjectsDestructed map[common.Address]struct{}
// DB error.
// State objects are used by the consensus core and VM which are
// unable to deal with database-level errors. Any error that occurs
// during a database read is memoized here and will eventually be returned
// by StateDB.Commit.
dbErr error
// The refund counter, also used by state transitioning. // The refund counter, also used by state transitioning.
refund *big.Int refund *big.Int
@ -86,16 +77,14 @@ type StateDB struct {
} }
// Create a new state from a given trie // Create a new state from a given trie
func New(root common.Hash, db ethdb.Database) (*StateDB, error) { func New(root common.Hash, db Database) (*StateDB, error) {
tr, err := trie.NewSecure(root, db, MaxTrieCacheGen) tr, err := db.OpenTrie(root)
if err != nil { if err != nil {
return nil, err return nil, err
} }
csc, _ := lru.New(codeSizeCacheSize)
return &StateDB{ return &StateDB{
db: db, db: db,
trie: tr, trie: tr,
codeSizeCache: csc,
stateObjects: make(map[common.Address]*stateObject), stateObjects: make(map[common.Address]*stateObject),
stateObjectsDirty: make(map[common.Address]struct{}), stateObjectsDirty: make(map[common.Address]struct{}),
stateObjectsDestructed: make(map[common.Address]struct{}), stateObjectsDestructed: make(map[common.Address]struct{}),
@ -105,36 +94,21 @@ func New(root common.Hash, db ethdb.Database) (*StateDB, error) {
}, nil }, nil
} }
// New creates a new statedb by reusing any journalled tries to avoid costly // setError remembers the first non-nil error it is called with.
// disk io. func (self *StateDB) setError(err error) {
func (self *StateDB) New(root common.Hash) (*StateDB, error) { if self.dbErr == nil {
self.lock.Lock() self.dbErr = err
defer self.lock.Unlock()
tr, err := self.openTrie(root)
if err != nil {
return nil, err
} }
return &StateDB{ }
db: self.db,
trie: tr, func (self *StateDB) Error() error {
codeSizeCache: self.codeSizeCache, return self.dbErr
stateObjects: make(map[common.Address]*stateObject),
stateObjectsDirty: make(map[common.Address]struct{}),
stateObjectsDestructed: make(map[common.Address]struct{}),
refund: new(big.Int),
logs: make(map[common.Hash][]*types.Log),
preimages: make(map[common.Hash][]byte),
}, nil
} }
// Reset clears out all emphemeral state objects from the state db, but keeps // Reset clears out all emphemeral state objects from the state db, but keeps
// the underlying state trie to avoid reloading data for the next operations. // the underlying state trie to avoid reloading data for the next operations.
func (self *StateDB) Reset(root common.Hash) error { func (self *StateDB) Reset(root common.Hash) error {
self.lock.Lock() tr, err := self.db.OpenTrie(root)
defer self.lock.Unlock()
tr, err := self.openTrie(root)
if err != nil { if err != nil {
return err return err
} }
@ -149,34 +123,9 @@ func (self *StateDB) Reset(root common.Hash) error {
self.logSize = 0 self.logSize = 0
self.preimages = make(map[common.Hash][]byte) self.preimages = make(map[common.Hash][]byte)
self.clearJournalAndRefund() self.clearJournalAndRefund()
return nil return nil
} }
// openTrie creates a trie. It uses an existing trie if one is available
// from the journal if available.
func (self *StateDB) openTrie(root common.Hash) (*trie.SecureTrie, error) {
for i := len(self.pastTries) - 1; i >= 0; i-- {
if self.pastTries[i].Hash() == root {
tr := *self.pastTries[i]
return &tr, nil
}
}
return trie.NewSecure(root, self.db, MaxTrieCacheGen)
}
func (self *StateDB) pushTrie(t *trie.SecureTrie) {
self.lock.Lock()
defer self.lock.Unlock()
if len(self.pastTries) >= maxPastTries {
copy(self.pastTries, self.pastTries[1:])
self.pastTries[len(self.pastTries)-1] = t
} else {
self.pastTries = append(self.pastTries, t)
}
}
func (self *StateDB) AddLog(log *types.Log) { func (self *StateDB) AddLog(log *types.Log) {
self.journal = append(self.journal, addLogChange{txhash: self.thash}) self.journal = append(self.journal, addLogChange{txhash: self.thash})
@ -254,10 +203,7 @@ func (self *StateDB) GetNonce(addr common.Address) uint64 {
func (self *StateDB) GetCode(addr common.Address) []byte { func (self *StateDB) GetCode(addr common.Address) []byte {
stateObject := self.getStateObject(addr) stateObject := self.getStateObject(addr)
if stateObject != nil { if stateObject != nil {
code := stateObject.Code(self.db) return stateObject.Code(self.db)
key := common.BytesToHash(stateObject.CodeHash())
self.codeSizeCache.Add(key, len(code))
return code
} }
return nil return nil
} }
@ -267,13 +213,12 @@ func (self *StateDB) GetCodeSize(addr common.Address) int {
if stateObject == nil { if stateObject == nil {
return 0 return 0
} }
key := common.BytesToHash(stateObject.CodeHash()) if stateObject.code != nil {
if cached, ok := self.codeSizeCache.Get(key); ok { return len(stateObject.code)
return cached.(int)
} }
size := len(stateObject.Code(self.db)) size, err := self.db.ContractCodeSize(stateObject.addrHash, common.BytesToHash(stateObject.CodeHash()))
if stateObject.dbErr == nil { if err != nil {
self.codeSizeCache.Add(key, size) self.setError(err)
} }
return size return size
} }
@ -296,7 +241,7 @@ func (self *StateDB) GetState(a common.Address, b common.Hash) common.Hash {
// StorageTrie returns the storage trie of an account. // StorageTrie returns the storage trie of an account.
// The return value is a copy and is nil for non-existent accounts. // The return value is a copy and is nil for non-existent accounts.
func (self *StateDB) StorageTrie(a common.Address) *trie.SecureTrie { func (self *StateDB) StorageTrie(a common.Address) Trie {
stateObject := self.getStateObject(a) stateObject := self.getStateObject(a)
if stateObject == nil { if stateObject == nil {
return nil return nil
@ -394,14 +339,14 @@ func (self *StateDB) updateStateObject(stateObject *stateObject) {
if err != nil { if err != nil {
panic(fmt.Errorf("can't encode object at %x: %v", addr[:], err)) panic(fmt.Errorf("can't encode object at %x: %v", addr[:], err))
} }
self.trie.Update(addr[:], data) self.setError(self.trie.TryUpdate(addr[:], data))
} }
// deleteStateObject removes the given object from the state trie. // deleteStateObject removes the given object from the state trie.
func (self *StateDB) deleteStateObject(stateObject *stateObject) { func (self *StateDB) deleteStateObject(stateObject *stateObject) {
stateObject.deleted = true stateObject.deleted = true
addr := stateObject.Address() addr := stateObject.Address()
self.trie.Delete(addr[:]) self.setError(self.trie.TryDelete(addr[:]))
} }
// Retrieve a state object given my the address. Returns nil if not found. // Retrieve a state object given my the address. Returns nil if not found.
@ -415,8 +360,9 @@ func (self *StateDB) getStateObject(addr common.Address) (stateObject *stateObje
} }
// Load the object from the database. // Load the object from the database.
enc := self.trie.Get(addr[:]) enc, err := self.trie.TryGet(addr[:])
if len(enc) == 0 { if len(enc) == 0 {
self.setError(err)
return nil return nil
} }
var data Account var data Account
@ -512,8 +458,6 @@ func (self *StateDB) Copy() *StateDB {
state := &StateDB{ state := &StateDB{
db: self.db, db: self.db,
trie: self.trie, trie: self.trie,
pastTries: self.pastTries,
codeSizeCache: self.codeSizeCache,
stateObjects: make(map[common.Address]*stateObject, len(self.stateObjectsDirty)), stateObjects: make(map[common.Address]*stateObject, len(self.stateObjectsDirty)),
stateObjectsDirty: make(map[common.Address]struct{}, len(self.stateObjectsDirty)), stateObjectsDirty: make(map[common.Address]struct{}, len(self.stateObjectsDirty)),
stateObjectsDestructed: make(map[common.Address]struct{}, len(self.stateObjectsDestructed)), stateObjectsDestructed: make(map[common.Address]struct{}, len(self.stateObjectsDestructed)),
@ -636,23 +580,6 @@ func (s *StateDB) DeleteSuicides() {
} }
} }
// Commit commits all state changes to the database.
func (s *StateDB) Commit(deleteEmptyObjects bool) (root common.Hash, err error) {
root, batch := s.CommitBatch(deleteEmptyObjects)
return root, batch.Write()
}
// CommitBatch commits all state changes to a write batch but does not
// execute the batch. It is used to validate state changes against
// the root hash stored in a block.
func (s *StateDB) CommitBatch(deleteEmptyObjects bool) (root common.Hash, batch ethdb.Batch) {
batch = s.db.NewBatch()
root, _ = s.CommitTo(batch, deleteEmptyObjects)
log.Debug("Trie cache stats after commit", "misses", trie.CacheMisses(), "unloads", trie.CacheUnloads())
return root, batch
}
func (s *StateDB) clearJournalAndRefund() { func (s *StateDB) clearJournalAndRefund() {
s.journal = nil s.journal = nil
s.validRevisions = s.validRevisions[:0] s.validRevisions = s.validRevisions[:0]
@ -690,8 +617,6 @@ func (s *StateDB) CommitTo(dbw trie.DatabaseWriter, deleteEmptyObjects bool) (ro
} }
// Write trie changes. // Write trie changes.
root, err = s.trie.CommitTo(dbw) root, err = s.trie.CommitTo(dbw)
if err == nil { log.Debug("Trie cache stats after commit", "misses", trie.CacheMisses(), "unloads", trie.CacheUnloads())
s.pushTrie(s.trie)
}
return root, err return root, err
} }

View File

@ -28,6 +28,8 @@ import (
"testing" "testing"
"testing/quick" "testing/quick"
check "gopkg.in/check.v1"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
@ -38,7 +40,7 @@ import (
func TestUpdateLeaks(t *testing.T) { func TestUpdateLeaks(t *testing.T) {
// Create an empty state database // Create an empty state database
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
state, _ := New(common.Hash{}, db) state, _ := New(common.Hash{}, NewDatabase(db))
// Update it with some accounts // Update it with some accounts
for i := byte(0); i < 255; i++ { for i := byte(0); i < 255; i++ {
@ -66,8 +68,8 @@ func TestIntermediateLeaks(t *testing.T) {
// Create two state databases, one transitioning to the final state, the other final from the beginning // Create two state databases, one transitioning to the final state, the other final from the beginning
transDb, _ := ethdb.NewMemDatabase() transDb, _ := ethdb.NewMemDatabase()
finalDb, _ := ethdb.NewMemDatabase() finalDb, _ := ethdb.NewMemDatabase()
transState, _ := New(common.Hash{}, transDb) transState, _ := New(common.Hash{}, NewDatabase(transDb))
finalState, _ := New(common.Hash{}, finalDb) finalState, _ := New(common.Hash{}, NewDatabase(finalDb))
modify := func(state *StateDB, addr common.Address, i, tweak byte) { modify := func(state *StateDB, addr common.Address, i, tweak byte) {
state.SetBalance(addr, big.NewInt(int64(11*i)+int64(tweak))) state.SetBalance(addr, big.NewInt(int64(11*i)+int64(tweak)))
@ -95,10 +97,10 @@ func TestIntermediateLeaks(t *testing.T) {
} }
// Commit and cross check the databases. // Commit and cross check the databases.
if _, err := transState.Commit(false); err != nil { if _, err := transState.CommitTo(transDb, false); err != nil {
t.Fatalf("failed to commit transition state: %v", err) t.Fatalf("failed to commit transition state: %v", err)
} }
if _, err := finalState.Commit(false); err != nil { if _, err := finalState.CommitTo(finalDb, false); err != nil {
t.Fatalf("failed to commit final state: %v", err) t.Fatalf("failed to commit final state: %v", err)
} }
for _, key := range finalDb.Keys() { for _, key := range finalDb.Keys() {
@ -282,7 +284,7 @@ func (test *snapshotTest) run() bool {
// Run all actions and create snapshots. // Run all actions and create snapshots.
var ( var (
db, _ = ethdb.NewMemDatabase() db, _ = ethdb.NewMemDatabase()
state, _ = New(common.Hash{}, db) state, _ = New(common.Hash{}, NewDatabase(db))
snapshotRevs = make([]int, len(test.snapshots)) snapshotRevs = make([]int, len(test.snapshots))
sindex = 0 sindex = 0
) )
@ -297,7 +299,7 @@ func (test *snapshotTest) run() bool {
// Revert all snapshots in reverse order. Each revert must yield a state // Revert all snapshots in reverse order. Each revert must yield a state
// that is equivalent to fresh state with all actions up the snapshot applied. // that is equivalent to fresh state with all actions up the snapshot applied.
for sindex--; sindex >= 0; sindex-- { for sindex--; sindex >= 0; sindex-- {
checkstate, _ := New(common.Hash{}, db) checkstate, _ := New(common.Hash{}, NewDatabase(db))
for _, action := range test.actions[:test.snapshots[sindex]] { for _, action := range test.actions[:test.snapshots[sindex]] {
action.fn(action, checkstate) action.fn(action, checkstate)
} }
@ -354,21 +356,19 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
return nil return nil
} }
func TestTouchDelete(t *testing.T) { func (s *StateSuite) TestTouchDelete(c *check.C) {
db, _ := ethdb.NewMemDatabase() s.state.GetOrNewStateObject(common.Address{})
state, _ := New(common.Hash{}, db) root, _ := s.state.CommitTo(s.db, false)
state.GetOrNewStateObject(common.Address{}) s.state.Reset(root)
root, _ := state.Commit(false)
state.Reset(root)
snapshot := state.Snapshot() snapshot := s.state.Snapshot()
state.AddBalance(common.Address{}, new(big.Int)) s.state.AddBalance(common.Address{}, new(big.Int))
if len(state.stateObjectsDirty) != 1 { if len(s.state.stateObjectsDirty) != 1 {
t.Fatal("expected one dirty state object") c.Fatal("expected one dirty state object")
} }
state.RevertToSnapshot(snapshot) s.state.RevertToSnapshot(snapshot)
if len(state.stateObjectsDirty) != 0 { if len(s.state.stateObjectsDirty) != 0 {
t.Fatal("expected no dirty state object") c.Fatal("expected no dirty state object")
} }
} }

View File

@ -36,9 +36,10 @@ type testAccount struct {
} }
// makeTestState create a sample test state to test node-wise reconstruction. // makeTestState create a sample test state to test node-wise reconstruction.
func makeTestState() (ethdb.Database, common.Hash, []*testAccount) { func makeTestState() (Database, *ethdb.MemDatabase, common.Hash, []*testAccount) {
// Create an empty state // Create an empty state
db, _ := ethdb.NewMemDatabase() mem, _ := ethdb.NewMemDatabase()
db := NewDatabase(mem)
state, _ := New(common.Hash{}, db) state, _ := New(common.Hash{}, db)
// Fill it with some arbitrary data // Fill it with some arbitrary data
@ -60,17 +61,17 @@ func makeTestState() (ethdb.Database, common.Hash, []*testAccount) {
state.updateStateObject(obj) state.updateStateObject(obj)
accounts = append(accounts, acc) accounts = append(accounts, acc)
} }
root, _ := state.Commit(false) root, _ := state.CommitTo(mem, false)
// Return the generated state // Return the generated state
return db, root, accounts return db, mem, root, accounts
} }
// checkStateAccounts cross references a reconstructed state with an expected // checkStateAccounts cross references a reconstructed state with an expected
// account array. // account array.
func checkStateAccounts(t *testing.T, db ethdb.Database, root common.Hash, accounts []*testAccount) { func checkStateAccounts(t *testing.T, db ethdb.Database, root common.Hash, accounts []*testAccount) {
// Check root availability and state contents // Check root availability and state contents
state, err := New(root, db) state, err := New(root, NewDatabase(db))
if err != nil { if err != nil {
t.Fatalf("failed to create state trie at %x: %v", root, err) t.Fatalf("failed to create state trie at %x: %v", root, err)
} }
@ -90,13 +91,28 @@ func checkStateAccounts(t *testing.T, db ethdb.Database, root common.Hash, accou
} }
} }
// checkStateConsistency checks that all nodes in a state trie are indeed present. // checkTrieConsistency checks that all nodes in a (sub-)trie are indeed present.
func checkTrieConsistency(db ethdb.Database, root common.Hash) error {
if v, _ := db.Get(root[:]); v == nil {
return nil // Consider a non existent state consistent.
}
trie, err := trie.New(root, db)
if err != nil {
return err
}
it := trie.NodeIterator(nil)
for it.Next(true) {
}
return it.Error()
}
// checkStateConsistency checks that all data of a state root is present.
func checkStateConsistency(db ethdb.Database, root common.Hash) error { func checkStateConsistency(db ethdb.Database, root common.Hash) error {
// Create and iterate a state trie rooted in a sub-node // Create and iterate a state trie rooted in a sub-node
if _, err := db.Get(root.Bytes()); err != nil { if _, err := db.Get(root.Bytes()); err != nil {
return nil // Consider a non existent state consistent return nil // Consider a non existent state consistent.
} }
state, err := New(root, db) state, err := New(root, NewDatabase(db))
if err != nil { if err != nil {
return err return err
} }
@ -122,7 +138,7 @@ func TestIterativeStateSyncBatched(t *testing.T) { testIterativeStateSync(t,
func testIterativeStateSync(t *testing.T, batch int) { func testIterativeStateSync(t *testing.T, batch int) {
// Create a random state to copy // Create a random state to copy
srcDb, srcRoot, srcAccounts := makeTestState() _, srcMem, srcRoot, srcAccounts := makeTestState()
// Create a destination state and sync with the scheduler // Create a destination state and sync with the scheduler
dstDb, _ := ethdb.NewMemDatabase() dstDb, _ := ethdb.NewMemDatabase()
@ -132,7 +148,7 @@ func testIterativeStateSync(t *testing.T, batch int) {
for len(queue) > 0 { for len(queue) > 0 {
results := make([]trie.SyncResult, len(queue)) results := make([]trie.SyncResult, len(queue))
for i, hash := range queue { for i, hash := range queue {
data, err := srcDb.Get(hash.Bytes()) data, err := srcMem.Get(hash.Bytes())
if err != nil { if err != nil {
t.Fatalf("failed to retrieve node data for %x: %v", hash, err) t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
} }
@ -154,7 +170,7 @@ func testIterativeStateSync(t *testing.T, batch int) {
// partial results are returned, and the others sent only later. // partial results are returned, and the others sent only later.
func TestIterativeDelayedStateSync(t *testing.T) { func TestIterativeDelayedStateSync(t *testing.T) {
// Create a random state to copy // Create a random state to copy
srcDb, srcRoot, srcAccounts := makeTestState() _, srcMem, srcRoot, srcAccounts := makeTestState()
// Create a destination state and sync with the scheduler // Create a destination state and sync with the scheduler
dstDb, _ := ethdb.NewMemDatabase() dstDb, _ := ethdb.NewMemDatabase()
@ -165,7 +181,7 @@ func TestIterativeDelayedStateSync(t *testing.T) {
// Sync only half of the scheduled nodes // Sync only half of the scheduled nodes
results := make([]trie.SyncResult, len(queue)/2+1) results := make([]trie.SyncResult, len(queue)/2+1)
for i, hash := range queue[:len(results)] { for i, hash := range queue[:len(results)] {
data, err := srcDb.Get(hash.Bytes()) data, err := srcMem.Get(hash.Bytes())
if err != nil { if err != nil {
t.Fatalf("failed to retrieve node data for %x: %v", hash, err) t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
} }
@ -191,7 +207,7 @@ func TestIterativeRandomStateSyncBatched(t *testing.T) { testIterativeRandomS
func testIterativeRandomStateSync(t *testing.T, batch int) { func testIterativeRandomStateSync(t *testing.T, batch int) {
// Create a random state to copy // Create a random state to copy
srcDb, srcRoot, srcAccounts := makeTestState() _, srcMem, srcRoot, srcAccounts := makeTestState()
// Create a destination state and sync with the scheduler // Create a destination state and sync with the scheduler
dstDb, _ := ethdb.NewMemDatabase() dstDb, _ := ethdb.NewMemDatabase()
@ -205,7 +221,7 @@ func testIterativeRandomStateSync(t *testing.T, batch int) {
// Fetch all the queued nodes in a random order // Fetch all the queued nodes in a random order
results := make([]trie.SyncResult, 0, len(queue)) results := make([]trie.SyncResult, 0, len(queue))
for hash := range queue { for hash := range queue {
data, err := srcDb.Get(hash.Bytes()) data, err := srcMem.Get(hash.Bytes())
if err != nil { if err != nil {
t.Fatalf("failed to retrieve node data for %x: %v", hash, err) t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
} }
@ -231,7 +247,7 @@ func testIterativeRandomStateSync(t *testing.T, batch int) {
// partial results are returned (Even those randomly), others sent only later. // partial results are returned (Even those randomly), others sent only later.
func TestIterativeRandomDelayedStateSync(t *testing.T) { func TestIterativeRandomDelayedStateSync(t *testing.T) {
// Create a random state to copy // Create a random state to copy
srcDb, srcRoot, srcAccounts := makeTestState() _, srcMem, srcRoot, srcAccounts := makeTestState()
// Create a destination state and sync with the scheduler // Create a destination state and sync with the scheduler
dstDb, _ := ethdb.NewMemDatabase() dstDb, _ := ethdb.NewMemDatabase()
@ -247,7 +263,7 @@ func TestIterativeRandomDelayedStateSync(t *testing.T) {
for hash := range queue { for hash := range queue {
delete(queue, hash) delete(queue, hash)
data, err := srcDb.Get(hash.Bytes()) data, err := srcMem.Get(hash.Bytes())
if err != nil { if err != nil {
t.Fatalf("failed to retrieve node data for %x: %v", hash, err) t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
} }
@ -276,7 +292,9 @@ func TestIterativeRandomDelayedStateSync(t *testing.T) {
// the database. // the database.
func TestIncompleteStateSync(t *testing.T) { func TestIncompleteStateSync(t *testing.T) {
// Create a random state to copy // Create a random state to copy
srcDb, srcRoot, srcAccounts := makeTestState() _, srcMem, srcRoot, srcAccounts := makeTestState()
checkTrieConsistency(srcMem, srcRoot)
// Create a destination state and sync with the scheduler // Create a destination state and sync with the scheduler
dstDb, _ := ethdb.NewMemDatabase() dstDb, _ := ethdb.NewMemDatabase()
@ -288,7 +306,7 @@ func TestIncompleteStateSync(t *testing.T) {
// Fetch a batch of state nodes // Fetch a batch of state nodes
results := make([]trie.SyncResult, len(queue)) results := make([]trie.SyncResult, len(queue))
for i, hash := range queue { for i, hash := range queue {
data, err := srcDb.Get(hash.Bytes()) data, err := srcMem.Get(hash.Bytes())
if err != nil { if err != nil {
t.Fatalf("failed to retrieve node data for %x: %v", hash, err) t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
} }
@ -304,23 +322,20 @@ func TestIncompleteStateSync(t *testing.T) {
for _, result := range results { for _, result := range results {
added = append(added, result.Hash) added = append(added, result.Hash)
} }
// Check that all known sub-tries in the synced state is complete // Check that all known sub-tries added so far are complete or missing entirely.
for _, root := range added { checkSubtries:
// Skim through the accounts and make sure the root hash is not a code node for _, hash := range added {
codeHash := false
for _, acc := range srcAccounts { for _, acc := range srcAccounts {
if root == crypto.Keccak256Hash(acc.code) { if hash == crypto.Keccak256Hash(acc.code) {
codeHash = true continue checkSubtries // skip trie check of code nodes.
break
} }
} }
// If the root is a real trie node, check consistency // Can't use checkStateConsistency here because subtrie keys may have odd
if !codeHash { // length and crash in LeafKey.
if err := checkStateConsistency(dstDb, root); err != nil { if err := checkTrieConsistency(dstDb, hash); err != nil {
t.Fatalf("state inconsistent: %v", err) t.Fatalf("state inconsistent: %v", err)
} }
} }
}
// Fetch the next batch to retrieve // Fetch the next batch to retrieve
queue = append(queue[:0], sched.Missing(1)...) queue = append(queue[:0], sched.Missing(1)...)
} }

View File

@ -44,7 +44,7 @@ func pricedTransaction(nonce uint64, gaslimit, gasprice *big.Int, key *ecdsa.Pri
func setupTxPool() (*TxPool, *ecdsa.PrivateKey) { func setupTxPool() (*TxPool, *ecdsa.PrivateKey) {
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db) statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
key, _ := crypto.GenerateKey() key, _ := crypto.GenerateKey()
newPool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) }) newPool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) })
@ -95,7 +95,7 @@ func TestStateChangeDuringPoolReset(t *testing.T) {
key, _ = crypto.GenerateKey() key, _ = crypto.GenerateKey()
address = crypto.PubkeyToAddress(key.PublicKey) address = crypto.PubkeyToAddress(key.PublicKey)
mux = new(event.TypeMux) mux = new(event.TypeMux)
statedb, _ = state.New(common.Hash{}, db) statedb, _ = state.New(common.Hash{}, state.NewDatabase(db))
trigger = false trigger = false
) )
@ -114,7 +114,7 @@ func TestStateChangeDuringPoolReset(t *testing.T) {
// a state change between those fetches. // a state change between those fetches.
stdb := statedb stdb := statedb
if trigger { if trigger {
statedb, _ = state.New(common.Hash{}, db) statedb, _ = state.New(common.Hash{}, state.NewDatabase(db))
// simulate that the new head block included tx0 and tx1 // simulate that the new head block included tx0 and tx1
statedb.SetNonce(address, 2) statedb.SetNonce(address, 2)
statedb.SetBalance(address, new(big.Int).SetUint64(params.Ether)) statedb.SetBalance(address, new(big.Int).SetUint64(params.Ether))
@ -292,7 +292,7 @@ func TestTransactionChainFork(t *testing.T) {
addr := crypto.PubkeyToAddress(key.PublicKey) addr := crypto.PubkeyToAddress(key.PublicKey)
resetState := func() { resetState := func() {
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db) statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
pool.currentState = func() (*state.StateDB, error) { return statedb, nil } pool.currentState = func() (*state.StateDB, error) { return statedb, nil }
currentState, _ := pool.currentState() currentState, _ := pool.currentState()
currentState.AddBalance(addr, big.NewInt(100000000000000)) currentState.AddBalance(addr, big.NewInt(100000000000000))
@ -318,7 +318,7 @@ func TestTransactionDoubleNonce(t *testing.T) {
addr := crypto.PubkeyToAddress(key.PublicKey) addr := crypto.PubkeyToAddress(key.PublicKey)
resetState := func() { resetState := func() {
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db) statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
pool.currentState = func() (*state.StateDB, error) { return statedb, nil } pool.currentState = func() (*state.StateDB, error) { return statedb, nil }
currentState, _ := pool.currentState() currentState, _ := pool.currentState()
currentState.AddBalance(addr, big.NewInt(100000000000000)) currentState.AddBalance(addr, big.NewInt(100000000000000))
@ -628,7 +628,7 @@ func TestTransactionQueueGlobalLimiting(t *testing.T) {
// Create the pool to test the limit enforcement with // Create the pool to test the limit enforcement with
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db) statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) }) pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) })
pool.resetState() pool.resetState()
@ -783,7 +783,7 @@ func TestTransactionPendingGlobalLimiting(t *testing.T) {
// Create the pool to test the limit enforcement with // Create the pool to test the limit enforcement with
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db) statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) }) pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) })
pool.resetState() pool.resetState()
@ -835,7 +835,7 @@ func TestTransactionCapClearsFromAll(t *testing.T) {
// Create the pool to test the limit enforcement with // Create the pool to test the limit enforcement with
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db) statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) }) pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) })
pool.resetState() pool.resetState()
@ -868,7 +868,7 @@ func TestTransactionPendingMinimumAllowance(t *testing.T) {
// Create the pool to test the limit enforcement with // Create the pool to test the limit enforcement with
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db) statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) }) pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) })
pool.resetState() pool.resetState()
@ -913,7 +913,7 @@ func TestTransactionPendingMinimumAllowance(t *testing.T) {
func TestTransactionPoolRepricing(t *testing.T) { func TestTransactionPoolRepricing(t *testing.T) {
// Create the pool to test the pricing enforcement with // Create the pool to test the pricing enforcement with
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db) statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) }) pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) })
pool.resetState() pool.resetState()
@ -1006,7 +1006,7 @@ func TestTransactionPoolUnderpricing(t *testing.T) {
// Create the pool to test the pricing enforcement with // Create the pool to test the pricing enforcement with
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db) statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) }) pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) })
pool.resetState() pool.resetState()
@ -1091,7 +1091,7 @@ func TestTransactionPoolUnderpricing(t *testing.T) {
func TestTransactionReplacement(t *testing.T) { func TestTransactionReplacement(t *testing.T) {
// Create the pool to test the pricing enforcement with // Create the pool to test the pricing enforcement with
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
statedb, _ := state.New(common.Hash{}, db) statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) }) pool := NewTxPool(DefaultTxPoolConfig, params.TestChainConfig, new(event.TypeMux), func() (*state.StateDB, error) { return statedb, nil }, func() *big.Int { return big.NewInt(1000000) })
pool.resetState() pool.resetState()

View File

@ -102,7 +102,7 @@ func Execute(code, input []byte, cfg *Config) ([]byte, *state.StateDB, error) {
if cfg.State == nil { if cfg.State == nil {
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
cfg.State, _ = state.New(common.Hash{}, db) cfg.State, _ = state.New(common.Hash{}, state.NewDatabase(db))
} }
var ( var (
address = common.StringToAddress("contract") address = common.StringToAddress("contract")
@ -133,7 +133,7 @@ func Create(input []byte, cfg *Config) ([]byte, common.Address, uint64, error) {
if cfg.State == nil { if cfg.State == nil {
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
cfg.State, _ = state.New(common.Hash{}, db) cfg.State, _ = state.New(common.Hash{}, state.NewDatabase(db))
} }
var ( var (
vmenv = NewEnv(cfg, cfg.State) vmenv = NewEnv(cfg, cfg.State)

View File

@ -95,7 +95,7 @@ func TestExecute(t *testing.T) {
func TestCall(t *testing.T) { func TestCall(t *testing.T) {
db, _ := ethdb.NewMemDatabase() db, _ := ethdb.NewMemDatabase()
state, _ := state.New(common.Hash{}, db) state, _ := state.New(common.Hash{}, state.NewDatabase(db))
address := common.HexToAddress("0x0a") address := common.HexToAddress("0x0a")
state.SetCode(address, []byte{ state.SetCode(address, []byte{
byte(vm.PUSH1), 10, byte(vm.PUSH1), 10,

View File

@ -637,7 +637,7 @@ func (api *PrivateDebugAPI) StorageRangeAt(ctx context.Context, blockHash common
return storageRangeAt(st, keyStart, maxResult), nil return storageRangeAt(st, keyStart, maxResult), nil
} }
func storageRangeAt(st *trie.SecureTrie, start []byte, maxResult int) StorageRangeResult { func storageRangeAt(st state.Trie, start []byte, maxResult int) StorageRangeResult {
it := trie.NewIterator(st.NodeIterator(start)) it := trie.NewIterator(st.NodeIterator(start))
result := StorageRangeResult{Storage: storageMap{}} result := StorageRangeResult{Storage: storageMap{}}
for i := 0; i < maxResult && it.Next(); i++ { for i := 0; i < maxResult && it.Next(); i++ {

View File

@ -31,7 +31,6 @@ import (
"github.com/ethereum/go-ethereum/eth/gasprice" "github.com/ethereum/go-ethereum/eth/gasprice"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/internal/ethapi"
"github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/rpc"
) )
@ -81,11 +80,11 @@ func (b *EthApiBackend) BlockByNumber(ctx context.Context, blockNr rpc.BlockNumb
return b.eth.blockchain.GetBlockByNumber(uint64(blockNr)), nil return b.eth.blockchain.GetBlockByNumber(uint64(blockNr)), nil
} }
func (b *EthApiBackend) StateAndHeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (ethapi.State, *types.Header, error) { func (b *EthApiBackend) StateAndHeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (*state.StateDB, *types.Header, error) {
// Pending state is only known by the miner // Pending state is only known by the miner
if blockNr == rpc.PendingBlockNumber { if blockNr == rpc.PendingBlockNumber {
block, state := b.eth.miner.Pending() block, state := b.eth.miner.Pending()
return EthApiState{state}, block.Header(), nil return state, block.Header(), nil
} }
// Otherwise resolve the block number and return its state // Otherwise resolve the block number and return its state
header, err := b.HeaderByNumber(ctx, blockNr) header, err := b.HeaderByNumber(ctx, blockNr)
@ -93,7 +92,7 @@ func (b *EthApiBackend) StateAndHeaderByNumber(ctx context.Context, blockNr rpc.
return nil, nil, err return nil, nil, err
} }
stateDb, err := b.eth.BlockChain().StateAt(header.Root) stateDb, err := b.eth.BlockChain().StateAt(header.Root)
return EthApiState{stateDb}, header, err return stateDb, header, err
} }
func (b *EthApiBackend) GetBlock(ctx context.Context, blockHash common.Hash) (*types.Block, error) { func (b *EthApiBackend) GetBlock(ctx context.Context, blockHash common.Hash) (*types.Block, error) {
@ -108,14 +107,12 @@ func (b *EthApiBackend) GetTd(blockHash common.Hash) *big.Int {
return b.eth.blockchain.GetTdByHash(blockHash) return b.eth.blockchain.GetTdByHash(blockHash)
} }
func (b *EthApiBackend) GetEVM(ctx context.Context, msg core.Message, state ethapi.State, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) { func (b *EthApiBackend) GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) {
statedb := state.(EthApiState).state state.SetBalance(msg.From(), math.MaxBig256)
from := statedb.GetOrNewStateObject(msg.From())
from.SetBalance(math.MaxBig256)
vmError := func() error { return nil } vmError := func() error { return nil }
context := core.NewEVMContext(msg, header, b.eth.BlockChain(), nil) context := core.NewEVMContext(msg, header, b.eth.BlockChain(), nil)
return vm.NewEVM(context, statedb, b.eth.chainConfig, vmCfg), vmError, nil return vm.NewEVM(context, state, b.eth.chainConfig, vmCfg), vmError, nil
} }
func (b *EthApiBackend) SendTx(ctx context.Context, signedTx *types.Transaction) error { func (b *EthApiBackend) SendTx(ctx context.Context, signedTx *types.Transaction) error {
@ -200,23 +197,3 @@ func (b *EthApiBackend) EventMux() *event.TypeMux {
func (b *EthApiBackend) AccountManager() *accounts.Manager { func (b *EthApiBackend) AccountManager() *accounts.Manager {
return b.eth.AccountManager() return b.eth.AccountManager()
} }
type EthApiState struct {
state *state.StateDB
}
func (s EthApiState) GetBalance(ctx context.Context, addr common.Address) (*big.Int, error) {
return s.state.GetBalance(addr), nil
}
func (s EthApiState) GetCode(ctx context.Context, addr common.Address) ([]byte, error) {
return s.state.GetCode(addr), nil
}
func (s EthApiState) GetState(ctx context.Context, a common.Address, b common.Hash) (common.Hash, error) {
return s.state.GetState(a, b), nil
}
func (s EthApiState) GetNonce(ctx context.Context, addr common.Address) (uint64, error) {
return s.state.GetNonce(addr), nil
}

View File

@ -32,7 +32,7 @@ func TestStorageRangeAt(t *testing.T) {
// Create a state where account 0x010000... has a few storage entries. // Create a state where account 0x010000... has a few storage entries.
var ( var (
db, _ = ethdb.NewMemDatabase() db, _ = ethdb.NewMemDatabase()
state, _ = state.New(common.Hash{}, db) state, _ = state.New(common.Hash{}, state.NewDatabase(db))
addr = common.Address{0x01} addr = common.Address{0x01}
keys = []common.Hash{ // hashes of Keys of storage keys = []common.Hash{ // hashes of Keys of storage
common.HexToHash("340dd630ad21bf010b4e676dbfa9ba9a02175262d1fa356232cfde6cb5b47ef2"), common.HexToHash("340dd630ad21bf010b4e676dbfa9ba9a02175262d1fa356232cfde6cb5b47ef2"),

View File

@ -54,14 +54,12 @@ func NewContractBackend(apiBackend ethapi.Backend) *ContractBackend {
// CodeAt retrieves any code associated with the contract from the local API. // CodeAt retrieves any code associated with the contract from the local API.
func (b *ContractBackend) CodeAt(ctx context.Context, contract common.Address, blockNum *big.Int) ([]byte, error) { func (b *ContractBackend) CodeAt(ctx context.Context, contract common.Address, blockNum *big.Int) ([]byte, error) {
out, err := b.bcapi.GetCode(ctx, contract, toBlockNumber(blockNum)) return b.bcapi.GetCode(ctx, contract, toBlockNumber(blockNum))
return common.FromHex(out), err
} }
// CodeAt retrieves any code associated with the contract from the local API. // CodeAt retrieves any code associated with the contract from the local API.
func (b *ContractBackend) PendingCodeAt(ctx context.Context, contract common.Address) ([]byte, error) { func (b *ContractBackend) PendingCodeAt(ctx context.Context, contract common.Address) ([]byte, error) {
out, err := b.bcapi.GetCode(ctx, contract, rpc.PendingBlockNumber) return b.bcapi.GetCode(ctx, contract, rpc.PendingBlockNumber)
return common.FromHex(out), err
} }
// ContractCall implements bind.ContractCaller executing an Ethereum contract // ContractCall implements bind.ContractCaller executing an Ethereum contract

View File

@ -657,7 +657,7 @@ func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, leng
index = len(tester.ownHashes) - lengths[len(lengths)-1] + int(tester.downloader.queue.fastSyncPivot) index = len(tester.ownHashes) - lengths[len(lengths)-1] + int(tester.downloader.queue.fastSyncPivot)
} }
if index > 0 { if index > 0 {
if statedb, err := state.New(tester.ownHeaders[tester.ownHashes[index]].Root, tester.stateDb); statedb == nil || err != nil { if statedb, err := state.New(tester.ownHeaders[tester.ownHashes[index]].Root, state.NewDatabase(tester.stateDb)); statedb == nil || err != nil {
t.Fatalf("state reconstruction failed: %v", err) t.Fatalf("state reconstruction failed: %v", err)
} }
} }

View File

@ -374,7 +374,7 @@ func testGetNodeData(t *testing.T, protocol int) {
} }
accounts := []common.Address{testBank, acc1Addr, acc2Addr} accounts := []common.Address{testBank, acc1Addr, acc2Addr}
for i := uint64(0); i <= pm.blockchain.CurrentBlock().NumberU64(); i++ { for i := uint64(0); i <= pm.blockchain.CurrentBlock().NumberU64(); i++ {
trie, _ := state.New(pm.blockchain.GetBlockByNumber(i).Root(), statedb) trie, _ := state.New(pm.blockchain.GetBlockByNumber(i).Root(), state.NewDatabase(statedb))
for j, acc := range accounts { for j, acc := range accounts {
state, _ := pm.blockchain.State() state, _ := pm.blockchain.State()

View File

@ -447,8 +447,8 @@ func (s *PublicBlockChainAPI) GetBalance(ctx context.Context, address common.Add
if state == nil || err != nil { if state == nil || err != nil {
return nil, err return nil, err
} }
b := state.GetBalance(address)
return state.GetBalance(ctx, address) return b, state.Error()
} }
// GetBlockByNumber returns the requested block. When blockNr is -1 the chain head is returned. When fullTx is true all // GetBlockByNumber returns the requested block. When blockNr is -1 the chain head is returned. When fullTx is true all
@ -529,31 +529,25 @@ func (s *PublicBlockChainAPI) GetUncleCountByBlockHash(ctx context.Context, bloc
} }
// GetCode returns the code stored at the given address in the state for the given block number. // GetCode returns the code stored at the given address in the state for the given block number.
func (s *PublicBlockChainAPI) GetCode(ctx context.Context, address common.Address, blockNr rpc.BlockNumber) (string, error) { func (s *PublicBlockChainAPI) GetCode(ctx context.Context, address common.Address, blockNr rpc.BlockNumber) (hexutil.Bytes, error) {
state, _, err := s.b.StateAndHeaderByNumber(ctx, blockNr) state, _, err := s.b.StateAndHeaderByNumber(ctx, blockNr)
if state == nil || err != nil { if state == nil || err != nil {
return "", err return nil, err
} }
res, err := state.GetCode(ctx, address) code := state.GetCode(address)
if len(res) == 0 || err != nil { // backwards compatibility return code, state.Error()
return "0x", err
}
return common.ToHex(res), nil
} }
// GetStorageAt returns the storage from the state at the given address, key and // GetStorageAt returns the storage from the state at the given address, key and
// block number. The rpc.LatestBlockNumber and rpc.PendingBlockNumber meta block // block number. The rpc.LatestBlockNumber and rpc.PendingBlockNumber meta block
// numbers are also allowed. // numbers are also allowed.
func (s *PublicBlockChainAPI) GetStorageAt(ctx context.Context, address common.Address, key string, blockNr rpc.BlockNumber) (string, error) { func (s *PublicBlockChainAPI) GetStorageAt(ctx context.Context, address common.Address, key string, blockNr rpc.BlockNumber) (hexutil.Bytes, error) {
state, _, err := s.b.StateAndHeaderByNumber(ctx, blockNr) state, _, err := s.b.StateAndHeaderByNumber(ctx, blockNr)
if state == nil || err != nil { if state == nil || err != nil {
return "0x", err return nil, err
} }
res, err := state.GetState(ctx, address, common.HexToHash(key)) res := state.GetState(address, common.HexToHash(key))
if err != nil { return res[:], state.Error()
return "0x", err
}
return res.Hex(), nil
} }
// callmsg is the message type used for call transitions. // callmsg is the message type used for call transitions.
@ -978,11 +972,8 @@ func (s *PublicTransactionPoolAPI) GetTransactionCount(ctx context.Context, addr
if state == nil || err != nil { if state == nil || err != nil {
return nil, err return nil, err
} }
nonce, err := state.GetNonce(ctx, address) nonce := state.GetNonce(address)
if err != nil { return (*hexutil.Uint64)(&nonce), state.Error()
return nil, err
}
return (*hexutil.Uint64)(&nonce), nil
} }
// getTransactionBlockData fetches the meta data for the given transaction from the chain database. This is useful to // getTransactionBlockData fetches the meta data for the given transaction from the chain database. This is useful to

View File

@ -24,6 +24,7 @@ import (
"github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/accounts"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/eth/downloader" "github.com/ethereum/go-ethereum/eth/downloader"
@ -47,11 +48,12 @@ type Backend interface {
SetHead(number uint64) SetHead(number uint64)
HeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (*types.Header, error) HeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (*types.Header, error)
BlockByNumber(ctx context.Context, blockNr rpc.BlockNumber) (*types.Block, error) BlockByNumber(ctx context.Context, blockNr rpc.BlockNumber) (*types.Block, error)
StateAndHeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (State, *types.Header, error) StateAndHeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (*state.StateDB, *types.Header, error)
GetBlock(ctx context.Context, blockHash common.Hash) (*types.Block, error) GetBlock(ctx context.Context, blockHash common.Hash) (*types.Block, error)
GetReceipts(ctx context.Context, blockHash common.Hash) (types.Receipts, error) GetReceipts(ctx context.Context, blockHash common.Hash) (types.Receipts, error)
GetTd(blockHash common.Hash) *big.Int GetTd(blockHash common.Hash) *big.Int
GetEVM(ctx context.Context, msg core.Message, state State, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error)
// TxPool API // TxPool API
SendTx(ctx context.Context, signedTx *types.Transaction) error SendTx(ctx context.Context, signedTx *types.Transaction) error
RemoveTx(txHash common.Hash) RemoveTx(txHash common.Hash)
@ -65,13 +67,6 @@ type Backend interface {
CurrentBlock() *types.Block CurrentBlock() *types.Block
} }
type State interface {
GetBalance(ctx context.Context, addr common.Address) (*big.Int, error)
GetCode(ctx context.Context, addr common.Address) ([]byte, error)
GetState(ctx context.Context, a common.Address, b common.Hash) (common.Hash, error)
GetNonce(ctx context.Context, addr common.Address) (uint64, error)
}
func GetAPIs(apiBackend Backend) []rpc.API { func GetAPIs(apiBackend Backend) []rpc.API {
nonceLock := new(AddrLocker) nonceLock := new(AddrLocker)
return []rpc.API{ return []rpc.API{

View File

@ -24,13 +24,13 @@ import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/math" "github.com/ethereum/go-ethereum/common/math"
"github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/eth/downloader" "github.com/ethereum/go-ethereum/eth/downloader"
"github.com/ethereum/go-ethereum/eth/gasprice" "github.com/ethereum/go-ethereum/eth/gasprice"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/internal/ethapi"
"github.com/ethereum/go-ethereum/light" "github.com/ethereum/go-ethereum/light"
"github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/rpc"
@ -70,12 +70,12 @@ func (b *LesApiBackend) BlockByNumber(ctx context.Context, blockNr rpc.BlockNumb
return b.GetBlock(ctx, header.Hash()) return b.GetBlock(ctx, header.Hash())
} }
func (b *LesApiBackend) StateAndHeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (ethapi.State, *types.Header, error) { func (b *LesApiBackend) StateAndHeaderByNumber(ctx context.Context, blockNr rpc.BlockNumber) (*state.StateDB, *types.Header, error) {
header, err := b.HeaderByNumber(ctx, blockNr) header, err := b.HeaderByNumber(ctx, blockNr)
if header == nil || err != nil { if header == nil || err != nil {
return nil, nil, err return nil, nil, err
} }
return light.NewLightState(light.StateTrieID(header), b.eth.odr), header, nil return light.NewState(ctx, header, b.eth.odr), header, nil
} }
func (b *LesApiBackend) GetBlock(ctx context.Context, blockHash common.Hash) (*types.Block, error) { func (b *LesApiBackend) GetBlock(ctx context.Context, blockHash common.Hash) (*types.Block, error) {
@ -90,18 +90,10 @@ func (b *LesApiBackend) GetTd(blockHash common.Hash) *big.Int {
return b.eth.blockchain.GetTdByHash(blockHash) return b.eth.blockchain.GetTdByHash(blockHash)
} }
func (b *LesApiBackend) GetEVM(ctx context.Context, msg core.Message, state ethapi.State, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) { func (b *LesApiBackend) GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) {
stateDb := state.(*light.LightState).Copy() state.SetBalance(msg.From(), math.MaxBig256)
addr := msg.From()
from, err := stateDb.GetOrNewStateObject(ctx, addr)
if err != nil {
return nil, nil, err
}
from.SetBalance(math.MaxBig256)
vmstate := light.NewVMState(ctx, stateDb)
context := core.NewEVMContext(msg, header, b.eth.blockchain, nil) context := core.NewEVMContext(msg, header, b.eth.blockchain, nil)
return vm.NewEVM(context, vmstate, b.eth.chainConfig, vmCfg), vmstate.Error, nil return vm.NewEVM(context, state, b.eth.chainConfig, vmCfg), state.Error, nil
} }
func (b *LesApiBackend) SendTx(ctx context.Context, signedTx *types.Transaction) error { func (b *LesApiBackend) SendTx(ctx context.Context, signedTx *types.Transaction) error {

View File

@ -75,25 +75,24 @@ func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainCon
dummyAddr := common.HexToAddress("1234567812345678123456781234567812345678") dummyAddr := common.HexToAddress("1234567812345678123456781234567812345678")
acc := []common.Address{testBankAddress, acc1Addr, acc2Addr, dummyAddr} acc := []common.Address{testBankAddress, acc1Addr, acc2Addr, dummyAddr}
var res []byte var (
res []byte
st *state.StateDB
err error
)
for _, addr := range acc { for _, addr := range acc {
if bc != nil { if bc != nil {
header := bc.GetHeaderByHash(bhash) header := bc.GetHeaderByHash(bhash)
st, err := state.New(header.Root, db) st, err = state.New(header.Root, state.NewDatabase(db))
} else {
header := lc.GetHeaderByHash(bhash)
st = light.NewState(ctx, header, lc.Odr())
}
if err == nil { if err == nil {
bal := st.GetBalance(addr) bal := st.GetBalance(addr)
rlp, _ := rlp.EncodeToBytes(bal) rlp, _ := rlp.EncodeToBytes(bal)
res = append(res, rlp...) res = append(res, rlp...)
} }
} else {
header := lc.GetHeaderByHash(bhash)
st := light.NewLightState(light.StateTrieID(header), lc.Odr())
bal, err := st.GetBalance(ctx, addr)
if err == nil {
rlp, _ := rlp.EncodeToBytes(bal)
res = append(res, rlp...)
}
}
} }
return res return res
@ -115,7 +114,7 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai
data[35] = byte(i) data[35] = byte(i)
if bc != nil { if bc != nil {
header := bc.GetHeaderByHash(bhash) header := bc.GetHeaderByHash(bhash)
statedb, err := state.New(header.Root, db) statedb, err := state.New(header.Root, state.NewDatabase(db))
if err == nil { if err == nil {
from := statedb.GetOrNewStateObject(testBankAddress) from := statedb.GetOrNewStateObject(testBankAddress)
@ -133,26 +132,18 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai
} }
} else { } else {
header := lc.GetHeaderByHash(bhash) header := lc.GetHeaderByHash(bhash)
state := light.NewLightState(light.StateTrieID(header), lc.Odr()) state := light.NewState(ctx, header, lc.Odr())
vmstate := light.NewVMState(ctx, state) state.SetBalance(testBankAddress, math.MaxBig256)
from, err := state.GetOrNewStateObject(ctx, testBankAddress) msg := callmsg{types.NewMessage(testBankAddress, &testContractAddr, 0, new(big.Int), big.NewInt(100000), new(big.Int), data, false)}
if err == nil {
from.SetBalance(math.MaxBig256)
msg := callmsg{types.NewMessage(from.Address(), &testContractAddr, 0, new(big.Int), big.NewInt(100000), new(big.Int), data, false)}
context := core.NewEVMContext(msg, header, lc, nil) context := core.NewEVMContext(msg, header, lc, nil)
vmenv := vm.NewEVM(context, vmstate, config, vm.Config{}) vmenv := vm.NewEVM(context, state, config, vm.Config{})
//vmenv := light.NewEnv(ctx, state, config, lc, msg, header, vm.Config{})
gp := new(core.GasPool).AddGas(math.MaxBig256) gp := new(core.GasPool).AddGas(math.MaxBig256)
ret, _, _ := core.ApplyMessage(vmenv, msg, gp) ret, _, _ := core.ApplyMessage(vmenv, msg, gp)
if vmstate.Error() == nil { if state.Error() == nil {
res = append(res, ret...) res = append(res, ret...)
} }
} }
} }
}
return res return res
} }

View File

@ -62,7 +62,7 @@ func tfCodeAccess(db ethdb.Database, bhash common.Hash, number uint64) light.Odr
return nil return nil
} }
sti := light.StateTrieID(header) sti := light.StateTrieID(header)
ci := light.StorageTrieID(sti, testContractAddr, common.Hash{}) ci := light.StorageTrieID(sti, crypto.Keccak256Hash(testContractAddr[:]), common.Hash{})
return &light.CodeRequest{Id: ci, Hash: crypto.Keccak256Hash(testContractCodeDeployed)} return &light.CodeRequest{Id: ci, Hash: crypto.Keccak256Hash(testContractCodeDeployed)}
} }

View File

@ -180,11 +180,6 @@ func (self *LightChain) Status() (td *big.Int, currentBlock common.Hash, genesis
return self.GetTd(hash, header.Number.Uint64()), hash, self.genesisBlock.Hash() return self.GetTd(hash, header.Number.Uint64()), hash, self.genesisBlock.Hash()
} }
// State returns a new mutable state based on the current HEAD block.
func (self *LightChain) State() *LightState {
return NewLightState(StateTrieID(self.hc.CurrentHeader()), self.odr)
}
// Reset purges the entire blockchain, restoring it to its genesis state. // Reset purges the entire blockchain, restoring it to its genesis state.
func (bc *LightChain) Reset() { func (bc *LightChain) Reset() {
bc.ResetWithGenesisBlock(bc.genesisBlock) bc.ResetWithGenesisBlock(bc.genesisBlock)

View File

@ -34,7 +34,7 @@ import (
// service is not required. // service is not required.
var NoOdr = context.Background() var NoOdr = context.Background()
// OdrBackend is an interface to a backend service that handles ODR retrievals // OdrBackend is an interface to a backend service that handles ODR retrievals type
type OdrBackend interface { type OdrBackend interface {
Database() ethdb.Database Database() ethdb.Database
Retrieve(ctx context.Context, req OdrRequest) error Retrieve(ctx context.Context, req OdrRequest) error
@ -66,11 +66,11 @@ func StateTrieID(header *types.Header) *TrieID {
// StorageTrieID returns a TrieID for a contract storage trie at a given account // StorageTrieID returns a TrieID for a contract storage trie at a given account
// of a given state trie. It also requires the root hash of the trie for // of a given state trie. It also requires the root hash of the trie for
// checking Merkle proofs. // checking Merkle proofs.
func StorageTrieID(state *TrieID, addr common.Address, root common.Hash) *TrieID { func StorageTrieID(state *TrieID, addrHash, root common.Hash) *TrieID {
return &TrieID{ return &TrieID{
BlockHash: state.BlockHash, BlockHash: state.BlockHash,
BlockNumber: state.BlockNumber, BlockNumber: state.BlockNumber,
AccKey: crypto.Keccak256(addr[:]), AccKey: addrHash[:],
Root: root, Root: root,
} }
} }
@ -102,7 +102,7 @@ func storeProof(db ethdb.Database, proof []rlp.RawValue) {
// CodeRequest is the ODR request type for retrieving contract code // CodeRequest is the ODR request type for retrieving contract code
type CodeRequest struct { type CodeRequest struct {
OdrRequest OdrRequest
Id *TrieID Id *TrieID // references storage trie of the account
Hash common.Hash Hash common.Hash
Data []byte Data []byte
} }

View File

@ -86,11 +86,11 @@ func (odr *testOdr) Retrieve(ctx context.Context, req OdrRequest) error {
return nil return nil
} }
type odrTestFn func(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) []byte type odrTestFn func(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) ([]byte, error)
func TestOdrGetBlockLes1(t *testing.T) { testChainOdr(t, 1, 1, odrGetBlock) } func TestOdrGetBlockLes1(t *testing.T) { testChainOdr(t, 1, odrGetBlock) }
func odrGetBlock(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) []byte { func odrGetBlock(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) ([]byte, error) {
var block *types.Block var block *types.Block
if bc != nil { if bc != nil {
block = bc.GetBlockByHash(bhash) block = bc.GetBlockByHash(bhash)
@ -98,15 +98,15 @@ func odrGetBlock(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc
block, _ = lc.GetBlockByHash(ctx, bhash) block, _ = lc.GetBlockByHash(ctx, bhash)
} }
if block == nil { if block == nil {
return nil return nil, nil
} }
rlp, _ := rlp.EncodeToBytes(block) rlp, _ := rlp.EncodeToBytes(block)
return rlp return rlp, nil
} }
func TestOdrGetReceiptsLes1(t *testing.T) { testChainOdr(t, 1, 1, odrGetReceipts) } func TestOdrGetReceiptsLes1(t *testing.T) { testChainOdr(t, 1, odrGetReceipts) }
func odrGetReceipts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) []byte { func odrGetReceipts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) ([]byte, error) {
var receipts types.Receipts var receipts types.Receipts
if bc != nil { if bc != nil {
receipts = core.GetBlockReceipts(db, bhash, core.GetBlockNumber(db, bhash)) receipts = core.GetBlockReceipts(db, bhash, core.GetBlockNumber(db, bhash))
@ -114,43 +114,37 @@ func odrGetReceipts(ctx context.Context, db ethdb.Database, bc *core.BlockChain,
receipts, _ = GetBlockReceipts(ctx, lc.Odr(), bhash, core.GetBlockNumber(db, bhash)) receipts, _ = GetBlockReceipts(ctx, lc.Odr(), bhash, core.GetBlockNumber(db, bhash))
} }
if receipts == nil { if receipts == nil {
return nil return nil, nil
} }
rlp, _ := rlp.EncodeToBytes(receipts) rlp, _ := rlp.EncodeToBytes(receipts)
return rlp return rlp, nil
} }
func TestOdrAccountsLes1(t *testing.T) { testChainOdr(t, 1, 1, odrAccounts) } func TestOdrAccountsLes1(t *testing.T) { testChainOdr(t, 1, odrAccounts) }
func odrAccounts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) []byte { func odrAccounts(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) ([]byte, error) {
dummyAddr := common.HexToAddress("1234567812345678123456781234567812345678") dummyAddr := common.HexToAddress("1234567812345678123456781234567812345678")
acc := []common.Address{testBankAddress, acc1Addr, acc2Addr, dummyAddr} acc := []common.Address{testBankAddress, acc1Addr, acc2Addr, dummyAddr}
var st *state.StateDB
if bc == nil {
header := lc.GetHeaderByHash(bhash)
st = NewState(ctx, header, lc.Odr())
} else {
header := bc.GetHeaderByHash(bhash)
st, _ = state.New(header.Root, state.NewDatabase(db))
}
var res []byte var res []byte
for _, addr := range acc { for _, addr := range acc {
if bc != nil {
header := bc.GetHeaderByHash(bhash)
st, err := state.New(header.Root, db)
if err == nil {
bal := st.GetBalance(addr) bal := st.GetBalance(addr)
rlp, _ := rlp.EncodeToBytes(bal) rlp, _ := rlp.EncodeToBytes(bal)
res = append(res, rlp...) res = append(res, rlp...)
} }
} else { return res, st.Error()
header := lc.GetHeaderByHash(bhash)
st := NewLightState(StateTrieID(header), lc.Odr())
bal, err := st.GetBalance(ctx, addr)
if err == nil {
rlp, _ := rlp.EncodeToBytes(bal)
res = append(res, rlp...)
}
}
}
return res
} }
func TestOdrContractCallLes1(t *testing.T) { testChainOdr(t, 1, 2, odrContractCall) } func TestOdrContractCallLes1(t *testing.T) { testChainOdr(t, 1, odrContractCall) }
type callmsg struct { type callmsg struct {
types.Message types.Message
@ -158,50 +152,42 @@ type callmsg struct {
func (callmsg) CheckNonce() bool { return false } func (callmsg) CheckNonce() bool { return false }
func odrContractCall(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) []byte { func odrContractCall(ctx context.Context, db ethdb.Database, bc *core.BlockChain, lc *LightChain, bhash common.Hash) ([]byte, error) {
data := common.Hex2Bytes("60CD26850000000000000000000000000000000000000000000000000000000000000000") data := common.Hex2Bytes("60CD26850000000000000000000000000000000000000000000000000000000000000000")
config := params.TestChainConfig config := params.TestChainConfig
var res []byte var res []byte
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
data[35] = byte(i) data[35] = byte(i)
if bc != nil {
header := bc.GetHeaderByHash(bhash)
statedb, err := state.New(header.Root, db)
if err == nil {
from := statedb.GetOrNewStateObject(testBankAddress)
from.SetBalance(math.MaxBig256)
msg := callmsg{types.NewMessage(from.Address(), &testContractAddr, 0, new(big.Int), big.NewInt(1000000), new(big.Int), data, false)} var (
st *state.StateDB
context := core.NewEVMContext(msg, header, bc, nil) header *types.Header
vmenv := vm.NewEVM(context, statedb, config, vm.Config{}) chain core.ChainContext
)
gp := new(core.GasPool).AddGas(math.MaxBig256) if bc == nil {
ret, _, _ := core.ApplyMessage(vmenv, msg, gp) chain = lc
res = append(res, ret...) header = lc.GetHeaderByHash(bhash)
} st = NewState(ctx, header, lc.Odr())
} else { } else {
header := lc.GetHeaderByHash(bhash) chain = bc
state := NewLightState(StateTrieID(header), lc.Odr()) header = bc.GetHeaderByHash(bhash)
vmstate := NewVMState(ctx, state) st, _ = state.New(header.Root, state.NewDatabase(db))
from, err := state.GetOrNewStateObject(ctx, testBankAddress) }
if err == nil {
from.SetBalance(math.MaxBig256)
msg := callmsg{types.NewMessage(from.Address(), &testContractAddr, 0, new(big.Int), big.NewInt(1000000), new(big.Int), data, false)} // Perform read-only call.
context := core.NewEVMContext(msg, header, lc, nil) st.SetBalance(testBankAddress, math.MaxBig256)
vmenv := vm.NewEVM(context, vmstate, config, vm.Config{}) msg := callmsg{types.NewMessage(testBankAddress, &testContractAddr, 0, new(big.Int), big.NewInt(1000000), new(big.Int), data, false)}
context := core.NewEVMContext(msg, header, chain, nil)
vmenv := vm.NewEVM(context, st, config, vm.Config{})
gp := new(core.GasPool).AddGas(math.MaxBig256) gp := new(core.GasPool).AddGas(math.MaxBig256)
ret, _, _ := core.ApplyMessage(vmenv, msg, gp) ret, _, _ := core.ApplyMessage(vmenv, msg, gp)
if vmstate.Error() == nil {
res = append(res, ret...) res = append(res, ret...)
if st.Error() != nil {
return res, st.Error()
} }
} }
} return res, nil
}
return res
} }
func testChainGen(i int, block *core.BlockGen) { func testChainGen(i int, block *core.BlockGen) {
@ -245,7 +231,7 @@ func testChainGen(i int, block *core.BlockGen) {
} }
} }
func testChainOdr(t *testing.T, protocol int, expFail uint64, fn odrTestFn) { func testChainOdr(t *testing.T, protocol int, fn odrTestFn) {
var ( var (
evmux = new(event.TypeMux) evmux = new(event.TypeMux)
sdb, _ = ethdb.NewMemDatabase() sdb, _ = ethdb.NewMemDatabase()
@ -258,46 +244,58 @@ func testChainOdr(t *testing.T, protocol int, expFail uint64, fn odrTestFn) {
blockchain, _ := core.NewBlockChain(sdb, params.TestChainConfig, ethash.NewFullFaker(), evmux, vm.Config{}) blockchain, _ := core.NewBlockChain(sdb, params.TestChainConfig, ethash.NewFullFaker(), evmux, vm.Config{})
gchain, _ := core.GenerateChain(params.TestChainConfig, genesis, sdb, 4, testChainGen) gchain, _ := core.GenerateChain(params.TestChainConfig, genesis, sdb, 4, testChainGen)
if _, err := blockchain.InsertChain(gchain); err != nil { if _, err := blockchain.InsertChain(gchain); err != nil {
panic(err) t.Fatal(err)
} }
odr := &testOdr{sdb: sdb, ldb: ldb} odr := &testOdr{sdb: sdb, ldb: ldb}
lightchain, _ := NewLightChain(odr, params.TestChainConfig, ethash.NewFullFaker(), evmux) lightchain, err := NewLightChain(odr, params.TestChainConfig, ethash.NewFullFaker(), evmux)
if err != nil {
t.Fatal(err)
}
headers := make([]*types.Header, len(gchain)) headers := make([]*types.Header, len(gchain))
for i, block := range gchain { for i, block := range gchain {
headers[i] = block.Header() headers[i] = block.Header()
} }
if _, err := lightchain.InsertHeaderChain(headers, 1); err != nil { if _, err := lightchain.InsertHeaderChain(headers, 1); err != nil {
panic(err) t.Fatal(err)
} }
test := func(expFail uint64) { test := func(expFail int) {
for i := uint64(0); i <= blockchain.CurrentHeader().Number.Uint64(); i++ { for i := uint64(0); i <= blockchain.CurrentHeader().Number.Uint64(); i++ {
bhash := core.GetCanonicalHash(sdb, i) bhash := core.GetCanonicalHash(sdb, i)
b1 := fn(NoOdr, sdb, blockchain, nil, bhash) b1, err := fn(NoOdr, sdb, blockchain, nil, bhash)
if err != nil {
t.Fatalf("error in full-node test for block %d: %v", i, err)
}
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel() defer cancel()
b2 := fn(ctx, ldb, nil, lightchain, bhash)
exp := i < uint64(expFail)
b2, err := fn(ctx, ldb, nil, lightchain, bhash)
if err != nil && exp {
t.Errorf("error in ODR test for block %d: %v", i, err)
}
eq := bytes.Equal(b1, b2) eq := bytes.Equal(b1, b2)
exp := i < expFail
if exp && !eq { if exp && !eq {
t.Errorf("odr mismatch") t.Errorf("ODR test output for block %d doesn't match full node", i)
}
if !exp && eq {
t.Errorf("unexpected odr match")
} }
} }
} }
odr.disable = true
// expect retrievals to fail (except genesis block) without a les peer // expect retrievals to fail (except genesis block) without a les peer
test(expFail) t.Log("checking without ODR")
odr.disable = false
// expect all retrievals to pass
test(5)
odr.disable = true odr.disable = true
test(1)
// expect all retrievals to pass with ODR enabled
t.Log("checking with ODR")
odr.disable = false
test(len(gchain))
// still expect all retrievals to pass, now data should be cached locally // still expect all retrievals to pass, now data should be cached locally
test(5) t.Log("checking without ODR, should be cached")
odr.disable = true
test(len(gchain))
} }

View File

@ -106,25 +106,6 @@ func GetCanonicalHash(ctx context.Context, odr OdrBackend, number uint64) (commo
return common.Hash{}, err return common.Hash{}, err
} }
// retrieveContractCode tries to retrieve the contract code of the given account
// with the given hash from the network (id points to the storage trie belonging
// to the same account)
func retrieveContractCode(ctx context.Context, odr OdrBackend, id *TrieID, hash common.Hash) ([]byte, error) {
if hash == sha3_nil {
return nil, nil
}
res, _ := odr.Database().Get(hash[:])
if res != nil {
return res, nil
}
r := &CodeRequest{Id: id, Hash: hash}
if err := odr.Retrieve(ctx, r); err != nil {
return nil, err
} else {
return r.Data, nil
}
}
// GetBodyRLP retrieves the block body (transactions and uncles) in RLP encoding. // GetBodyRLP retrieves the block body (transactions and uncles) in RLP encoding.
func GetBodyRLP(ctx context.Context, odr OdrBackend, hash common.Hash, number uint64) (rlp.RawValue, error) { func GetBodyRLP(ctx context.Context, odr OdrBackend, hash common.Hash, number uint64) (rlp.RawValue, error) {
if data := core.GetBodyRLP(odr.Database(), hash, number); data != nil { if data := core.GetBodyRLP(odr.Database(), hash, number); data != nil {

View File

@ -1,316 +0,0 @@
// Copyright 2015 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package light
import (
"context"
"math/big"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
)
// LightState is a memory representation of a state.
// This version is ODR capable, caching only the already accessed part of the
// state, retrieving unknown parts on-demand from the ODR backend. Changes are
// never stored in the local database, only in the memory objects.
type LightState struct {
odr OdrBackend
trie *LightTrie
id *TrieID
stateObjects map[string]*StateObject
refund *big.Int
}
// NewLightState creates a new LightState with the specified root.
// Note that the creation of a light state is always successful, even if the
// root is non-existent. In that case, ODR retrieval will always be unsuccessful
// and every operation will return with an error or wait for the context to be
// cancelled.
func NewLightState(id *TrieID, odr OdrBackend) *LightState {
var tr *LightTrie
if id != nil {
tr = NewLightTrie(id, odr, true)
}
return &LightState{
odr: odr,
trie: tr,
id: id,
stateObjects: make(map[string]*StateObject),
refund: new(big.Int),
}
}
// AddRefund adds an amount to the refund value collected during a vm execution
func (self *LightState) AddRefund(gas *big.Int) {
self.refund.Add(self.refund, gas)
}
// HasAccount returns true if an account exists at the given address
func (self *LightState) HasAccount(ctx context.Context, addr common.Address) (bool, error) {
so, err := self.GetStateObject(ctx, addr)
return so != nil, err
}
// GetBalance retrieves the balance from the given address or 0 if the account does
// not exist
func (self *LightState) GetBalance(ctx context.Context, addr common.Address) (*big.Int, error) {
stateObject, err := self.GetStateObject(ctx, addr)
if err != nil {
return common.Big0, err
}
if stateObject != nil {
return stateObject.balance, nil
}
return common.Big0, nil
}
// GetNonce returns the nonce at the given address or 0 if the account does
// not exist
func (self *LightState) GetNonce(ctx context.Context, addr common.Address) (uint64, error) {
stateObject, err := self.GetStateObject(ctx, addr)
if err != nil {
return 0, err
}
if stateObject != nil {
return stateObject.nonce, nil
}
return 0, nil
}
// GetCode returns the contract code at the given address or nil if the account
// does not exist
func (self *LightState) GetCode(ctx context.Context, addr common.Address) ([]byte, error) {
stateObject, err := self.GetStateObject(ctx, addr)
if err != nil {
return nil, err
}
if stateObject != nil {
return stateObject.code, nil
}
return nil, nil
}
// GetState returns the contract storage value at storage address b from the
// contract address a or common.Hash{} if the account does not exist
func (self *LightState) GetState(ctx context.Context, a common.Address, b common.Hash) (common.Hash, error) {
stateObject, err := self.GetStateObject(ctx, a)
if err == nil && stateObject != nil {
return stateObject.GetState(ctx, b)
}
return common.Hash{}, err
}
// HasSuicided returns true if the given account has been marked for deletion
// or false if the account does not exist
func (self *LightState) HasSuicided(ctx context.Context, addr common.Address) (bool, error) {
stateObject, err := self.GetStateObject(ctx, addr)
if err == nil && stateObject != nil {
return stateObject.remove, nil
}
return false, err
}
/*
* SETTERS
*/
// AddBalance adds the given amount to the balance of the specified account
func (self *LightState) AddBalance(ctx context.Context, addr common.Address, amount *big.Int) error {
stateObject, err := self.GetOrNewStateObject(ctx, addr)
if err == nil && stateObject != nil {
stateObject.AddBalance(amount)
}
return err
}
// SubBalance adds the given amount to the balance of the specified account
func (self *LightState) SubBalance(ctx context.Context, addr common.Address, amount *big.Int) error {
stateObject, err := self.GetOrNewStateObject(ctx, addr)
if err == nil && stateObject != nil {
stateObject.SubBalance(amount)
}
return err
}
// SetNonce sets the nonce of the specified account
func (self *LightState) SetNonce(ctx context.Context, addr common.Address, nonce uint64) error {
stateObject, err := self.GetOrNewStateObject(ctx, addr)
if err == nil && stateObject != nil {
stateObject.SetNonce(nonce)
}
return err
}
// SetCode sets the contract code at the specified account
func (self *LightState) SetCode(ctx context.Context, addr common.Address, code []byte) error {
stateObject, err := self.GetOrNewStateObject(ctx, addr)
if err == nil && stateObject != nil {
stateObject.SetCode(crypto.Keccak256Hash(code), code)
}
return err
}
// SetState sets the storage value at storage address key of the account addr
func (self *LightState) SetState(ctx context.Context, addr common.Address, key common.Hash, value common.Hash) error {
stateObject, err := self.GetOrNewStateObject(ctx, addr)
if err == nil && stateObject != nil {
stateObject.SetState(key, value)
}
return err
}
// Delete marks an account to be removed and clears its balance
func (self *LightState) Suicide(ctx context.Context, addr common.Address) (bool, error) {
stateObject, err := self.GetOrNewStateObject(ctx, addr)
if err == nil && stateObject != nil {
stateObject.MarkForDeletion()
stateObject.balance = new(big.Int)
return true, nil
}
return false, err
}
//
// Get, set, new state object methods
//
// GetStateObject returns the state object of the given account or nil if the
// account does not exist
func (self *LightState) GetStateObject(ctx context.Context, addr common.Address) (stateObject *StateObject, err error) {
stateObject = self.stateObjects[addr.Str()]
if stateObject != nil {
if stateObject.deleted {
stateObject = nil
}
return stateObject, nil
}
data, err := self.trie.Get(ctx, addr[:])
if err != nil {
return nil, err
}
if len(data) == 0 {
return nil, nil
}
stateObject, err = DecodeObject(ctx, self.id, addr, self.odr, []byte(data))
if err != nil {
return nil, err
}
self.SetStateObject(stateObject)
return stateObject, nil
}
// SetStateObject sets the state object of the given account
func (self *LightState) SetStateObject(object *StateObject) {
self.stateObjects[object.Address().Str()] = object
}
// GetOrNewStateObject returns the state object of the given account or creates a
// new one if the account does not exist
func (self *LightState) GetOrNewStateObject(ctx context.Context, addr common.Address) (*StateObject, error) {
stateObject, err := self.GetStateObject(ctx, addr)
if err == nil && (stateObject == nil || stateObject.deleted) {
stateObject, err = self.CreateStateObject(ctx, addr)
}
return stateObject, err
}
// newStateObject creates a state object whether it exists in the state or not
func (self *LightState) newStateObject(addr common.Address) *StateObject {
stateObject := NewStateObject(addr, self.odr)
self.stateObjects[addr.Str()] = stateObject
return stateObject
}
// CreateStateObject creates creates a new state object and takes ownership.
// This is different from "NewStateObject"
func (self *LightState) CreateStateObject(ctx context.Context, addr common.Address) (*StateObject, error) {
// Get previous (if any)
so, err := self.GetStateObject(ctx, addr)
if err != nil {
return nil, err
}
// Create a new one
newSo := self.newStateObject(addr)
// If it existed set the balance to the new account
if so != nil {
newSo.balance = so.balance
}
return newSo, nil
}
// ForEachStorage calls a callback function for every key/value pair found
// in the local storage cache. Note that unlike core/state.StateObject,
// light.StateObject only returns cached values and doesn't download the
// entire storage tree.
func (self *LightState) ForEachStorage(ctx context.Context, addr common.Address, cb func(key, value common.Hash) bool) error {
so, err := self.GetStateObject(ctx, addr)
if err != nil {
return err
}
if so == nil {
return nil
}
for h, v := range so.storage {
cb(h, v)
}
return nil
}
//
// Setting, copying of the state methods
//
// Copy creates a copy of the state
func (self *LightState) Copy() *LightState {
// ignore error - we assume state-to-be-copied always exists
state := NewLightState(nil, self.odr)
state.trie = self.trie
state.id = self.id
for k, stateObject := range self.stateObjects {
if stateObject.dirty {
state.stateObjects[k] = stateObject.Copy()
}
}
state.refund.Set(self.refund)
return state
}
// Set copies the contents of the given state onto this state, overwriting
// its contents
func (self *LightState) Set(state *LightState) {
self.trie = state.trie
self.stateObjects = state.stateObjects
self.refund = state.refund
}
// GetRefund returns the refund value collected during a vm execution
func (self *LightState) GetRefund() *big.Int {
return self.refund
}

View File

@ -1,275 +0,0 @@
// Copyright 2015 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package light
import (
"bytes"
"context"
"fmt"
"math/big"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/rlp"
)
var emptyCodeHash = crypto.Keccak256(nil)
// Code represents a contract code in binary form
type Code []byte
// String returns a string representation of the code
func (self Code) String() string {
return string(self) //strings.Join(Disassemble(self), " ")
}
// Storage is a memory map cache of a contract storage
type Storage map[common.Hash]common.Hash
// String returns a string representation of the storage cache
func (self Storage) String() (str string) {
for key, value := range self {
str += fmt.Sprintf("%X : %X\n", key, value)
}
return
}
// Copy copies the contents of a storage cache
func (self Storage) Copy() Storage {
cpy := make(Storage)
for key, value := range self {
cpy[key] = value
}
return cpy
}
// StateObject is a memory representation of an account or contract and its storage.
// This version is ODR capable, caching only the already accessed part of the
// storage, retrieving unknown parts on-demand from the ODR backend. Changes are
// never stored in the local database, only in the memory objects.
type StateObject struct {
odr OdrBackend
trie *LightTrie
// Address belonging to this account
address common.Address
// The balance of the account
balance *big.Int
// The nonce of the account
nonce uint64
// The code hash if code is present (i.e. a contract)
codeHash []byte
// The code for this account
code Code
// Cached storage (flushed when updated)
storage Storage
// Mark for deletion
// When an object is marked for deletion it will be delete from the trie
// during the "update" phase of the state transition
remove bool
deleted bool
dirty bool
}
// NewStateObject creates a new StateObject of the specified account address
func NewStateObject(address common.Address, odr OdrBackend) *StateObject {
object := &StateObject{
odr: odr,
address: address,
balance: new(big.Int),
dirty: true,
codeHash: emptyCodeHash,
storage: make(Storage),
}
object.trie = NewLightTrie(&TrieID{}, odr, true)
return object
}
// MarkForDeletion marks an account to be removed
func (self *StateObject) MarkForDeletion() {
self.remove = true
self.dirty = true
}
// getAddr gets the storage value at the given address from the trie
func (c *StateObject) getAddr(ctx context.Context, addr common.Hash) (common.Hash, error) {
var ret []byte
val, err := c.trie.Get(ctx, addr[:])
if err != nil {
return common.Hash{}, err
}
rlp.DecodeBytes(val, &ret)
return common.BytesToHash(ret), nil
}
// Storage returns the storage cache object of the account
func (self *StateObject) Storage() Storage {
return self.storage
}
// GetState returns the storage value at the given address from either the cache
// or the trie
func (self *StateObject) GetState(ctx context.Context, key common.Hash) (common.Hash, error) {
value, exists := self.storage[key]
if !exists {
var err error
value, err = self.getAddr(ctx, key)
if err != nil {
return common.Hash{}, err
}
if (value != common.Hash{}) {
self.storage[key] = value
}
}
return value, nil
}
// SetState sets the storage value at the given address
func (self *StateObject) SetState(k, value common.Hash) {
self.storage[k] = value
self.dirty = true
}
// AddBalance adds the given amount to the account balance
func (c *StateObject) AddBalance(amount *big.Int) {
c.SetBalance(new(big.Int).Add(c.balance, amount))
}
// SubBalance subtracts the given amount from the account balance
func (c *StateObject) SubBalance(amount *big.Int) {
c.SetBalance(new(big.Int).Sub(c.balance, amount))
}
// SetBalance sets the account balance to the given amount
func (c *StateObject) SetBalance(amount *big.Int) {
c.balance = amount
c.dirty = true
}
// ReturnGas returns the gas back to the origin. Used by the Virtual machine or Closures
func (c *StateObject) ReturnGas(gas *big.Int) {}
// Copy creates a copy of the state object
func (self *StateObject) Copy() *StateObject {
stateObject := NewStateObject(self.Address(), self.odr)
stateObject.balance.Set(self.balance)
stateObject.codeHash = common.CopyBytes(self.codeHash)
stateObject.nonce = self.nonce
stateObject.trie = self.trie
stateObject.code = self.code
stateObject.storage = self.storage.Copy()
stateObject.remove = self.remove
stateObject.dirty = self.dirty
stateObject.deleted = self.deleted
return stateObject
}
//
// Attribute accessors
//
// empty returns whether the account is considered empty.
func (self *StateObject) empty() bool {
return self.nonce == 0 && self.balance.Sign() == 0 && bytes.Equal(self.codeHash, emptyCodeHash)
}
// Balance returns the account balance
func (self *StateObject) Balance() *big.Int {
return self.balance
}
// Address returns the address of the contract/account
func (self *StateObject) Address() common.Address {
return self.address
}
// Code returns the contract code
func (self *StateObject) Code() []byte {
return self.code
}
// SetCode sets the contract code
func (self *StateObject) SetCode(hash common.Hash, code []byte) {
self.code = code
self.codeHash = hash[:]
self.dirty = true
}
// SetNonce sets the account nonce
func (self *StateObject) SetNonce(nonce uint64) {
self.nonce = nonce
self.dirty = true
}
// Nonce returns the account nonce
func (self *StateObject) Nonce() uint64 {
return self.nonce
}
// ForEachStorage calls a callback function for every key/value pair found
// in the local storage cache. Note that unlike core/state.StateObject,
// light.StateObject only returns cached values and doesn't download the
// entire storage tree.
func (self *StateObject) ForEachStorage(cb func(key, value common.Hash) bool) {
for h, v := range self.storage {
cb(h, v)
}
}
// Never called, but must be present to allow StateObject to be used
// as a vm.Account interface that also satisfies the vm.ContractRef
// interface. Interfaces are awesome.
func (self *StateObject) Value() *big.Int {
panic("Value on StateObject should never be called")
}
// Encoding
type extStateObject struct {
Nonce uint64
Balance *big.Int
Root common.Hash
CodeHash []byte
}
// DecodeObject decodes an RLP-encoded state object.
func DecodeObject(ctx context.Context, stateID *TrieID, address common.Address, odr OdrBackend, data []byte) (*StateObject, error) {
var (
obj = &StateObject{address: address, odr: odr, storage: make(Storage)}
ext extStateObject
err error
)
if err = rlp.DecodeBytes(data, &ext); err != nil {
return nil, err
}
trieID := StorageTrieID(stateID, address, ext.Root)
obj.trie = NewLightTrie(trieID, odr, true)
if !bytes.Equal(ext.CodeHash, emptyCodeHash) {
if obj.code, err = retrieveContractCode(ctx, obj.odr, trieID, common.BytesToHash(ext.CodeHash)); err != nil {
return nil, fmt.Errorf("can't find code for hash %x: %v", ext.CodeHash, err)
}
}
obj.nonce = ext.Nonce
obj.balance = ext.Balance
obj.codeHash = ext.CodeHash
return obj, nil
}

View File

@ -1,248 +0,0 @@
// Copyright 2015 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package light
import (
"bytes"
"context"
"math/big"
"testing"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethdb"
)
func makeTestState() (common.Hash, ethdb.Database) {
sdb, _ := ethdb.NewMemDatabase()
st, _ := state.New(common.Hash{}, sdb)
for i := byte(0); i < 100; i++ {
addr := common.Address{i}
for j := byte(0); j < 100; j++ {
st.SetState(addr, common.Hash{j}, common.Hash{i, j})
}
st.SetNonce(addr, 100)
st.AddBalance(addr, big.NewInt(int64(i)))
st.SetCode(addr, []byte{i, i, i})
}
root, _ := st.Commit(false)
return root, sdb
}
func TestLightStateOdr(t *testing.T) {
root, sdb := makeTestState()
header := &types.Header{Root: root, Number: big.NewInt(0)}
core.WriteHeader(sdb, header)
ldb, _ := ethdb.NewMemDatabase()
odr := &testOdr{sdb: sdb, ldb: ldb}
ls := NewLightState(StateTrieID(header), odr)
ctx := context.Background()
for i := byte(0); i < 100; i++ {
addr := common.Address{i}
err := ls.AddBalance(ctx, addr, big.NewInt(1000))
if err != nil {
t.Fatalf("Error adding balance to acc[%d]: %v", i, err)
}
err = ls.SetState(ctx, addr, common.Hash{100}, common.Hash{i, 100})
if err != nil {
t.Fatalf("Error setting storage of acc[%d]: %v", i, err)
}
}
addr := common.Address{100}
_, err := ls.CreateStateObject(ctx, addr)
if err != nil {
t.Fatalf("Error creating state object: %v", err)
}
err = ls.SetCode(ctx, addr, []byte{100, 100, 100})
if err != nil {
t.Fatalf("Error setting code: %v", err)
}
err = ls.AddBalance(ctx, addr, big.NewInt(1100))
if err != nil {
t.Fatalf("Error adding balance to acc[100]: %v", err)
}
for j := byte(0); j < 101; j++ {
err = ls.SetState(ctx, addr, common.Hash{j}, common.Hash{100, j})
if err != nil {
t.Fatalf("Error setting storage of acc[100]: %v", err)
}
}
err = ls.SetNonce(ctx, addr, 100)
if err != nil {
t.Fatalf("Error setting nonce for acc[100]: %v", err)
}
for i := byte(0); i < 101; i++ {
addr := common.Address{i}
bal, err := ls.GetBalance(ctx, addr)
if err != nil {
t.Fatalf("Error getting balance of acc[%d]: %v", i, err)
}
if bal.Int64() != int64(i)+1000 {
t.Fatalf("Incorrect balance at acc[%d]: expected %v, got %v", i, int64(i)+1000, bal.Int64())
}
nonce, err := ls.GetNonce(ctx, addr)
if err != nil {
t.Fatalf("Error getting nonce of acc[%d]: %v", i, err)
}
if nonce != 100 {
t.Fatalf("Incorrect nonce at acc[%d]: expected %v, got %v", i, 100, nonce)
}
code, err := ls.GetCode(ctx, addr)
exp := []byte{i, i, i}
if err != nil {
t.Fatalf("Error getting code of acc[%d]: %v", i, err)
}
if !bytes.Equal(code, exp) {
t.Fatalf("Incorrect code at acc[%d]: expected %v, got %v", i, exp, code)
}
for j := byte(0); j < 101; j++ {
exp := common.Hash{i, j}
val, err := ls.GetState(ctx, addr, common.Hash{j})
if err != nil {
t.Fatalf("Error retrieving acc[%d].storage[%d]: %v", i, j, err)
}
if val != exp {
t.Fatalf("Retrieved wrong value from acc[%d].storage[%d]: expected %04x, got %04x", i, j, exp, val)
}
}
}
}
func TestLightStateSetCopy(t *testing.T) {
root, sdb := makeTestState()
header := &types.Header{Root: root, Number: big.NewInt(0)}
core.WriteHeader(sdb, header)
ldb, _ := ethdb.NewMemDatabase()
odr := &testOdr{sdb: sdb, ldb: ldb}
ls := NewLightState(StateTrieID(header), odr)
ctx := context.Background()
for i := byte(0); i < 100; i++ {
addr := common.Address{i}
err := ls.AddBalance(ctx, addr, big.NewInt(1000))
if err != nil {
t.Fatalf("Error adding balance to acc[%d]: %v", i, err)
}
err = ls.SetState(ctx, addr, common.Hash{100}, common.Hash{i, 100})
if err != nil {
t.Fatalf("Error setting storage of acc[%d]: %v", i, err)
}
}
ls2 := ls.Copy()
for i := byte(0); i < 100; i++ {
addr := common.Address{i}
err := ls2.AddBalance(ctx, addr, big.NewInt(1000))
if err != nil {
t.Fatalf("Error adding balance to acc[%d]: %v", i, err)
}
err = ls2.SetState(ctx, addr, common.Hash{100}, common.Hash{i, 200})
if err != nil {
t.Fatalf("Error setting storage of acc[%d]: %v", i, err)
}
}
lsx := ls.Copy()
ls.Set(ls2)
ls2.Set(lsx)
for i := byte(0); i < 100; i++ {
addr := common.Address{i}
// check balance in ls
bal, err := ls.GetBalance(ctx, addr)
if err != nil {
t.Fatalf("Error getting balance to acc[%d]: %v", i, err)
}
if bal.Int64() != int64(i)+2000 {
t.Fatalf("Incorrect balance at ls.acc[%d]: expected %v, got %v", i, int64(i)+1000, bal.Int64())
}
// check balance in ls2
bal, err = ls2.GetBalance(ctx, addr)
if err != nil {
t.Fatalf("Error getting balance to acc[%d]: %v", i, err)
}
if bal.Int64() != int64(i)+1000 {
t.Fatalf("Incorrect balance at ls.acc[%d]: expected %v, got %v", i, int64(i)+1000, bal.Int64())
}
// check storage in ls
exp := common.Hash{i, 200}
val, err := ls.GetState(ctx, addr, common.Hash{100})
if err != nil {
t.Fatalf("Error retrieving acc[%d].storage[100]: %v", i, err)
}
if val != exp {
t.Fatalf("Retrieved wrong value from acc[%d].storage[100]: expected %04x, got %04x", i, exp, val)
}
// check storage in ls2
exp = common.Hash{i, 100}
val, err = ls2.GetState(ctx, addr, common.Hash{100})
if err != nil {
t.Fatalf("Error retrieving acc[%d].storage[100]: %v", i, err)
}
if val != exp {
t.Fatalf("Retrieved wrong value from acc[%d].storage[100]: expected %04x, got %04x", i, exp, val)
}
}
}
func TestLightStateDelete(t *testing.T) {
root, sdb := makeTestState()
header := &types.Header{Root: root, Number: big.NewInt(0)}
core.WriteHeader(sdb, header)
ldb, _ := ethdb.NewMemDatabase()
odr := &testOdr{sdb: sdb, ldb: ldb}
ls := NewLightState(StateTrieID(header), odr)
ctx := context.Background()
addr := common.Address{42}
b, err := ls.HasAccount(ctx, addr)
if err != nil {
t.Fatalf("HasAccount error: %v", err)
}
if !b {
t.Fatalf("HasAccount returned false, expected true")
}
b, err = ls.HasSuicided(ctx, addr)
if err != nil {
t.Fatalf("HasSuicided error: %v", err)
}
if b {
t.Fatalf("HasSuicided returned true, expected false")
}
ls.Suicide(ctx, addr)
b, err = ls.HasSuicided(ctx, addr)
if err != nil {
t.Fatalf("HasSuicided error: %v", err)
}
if !b {
t.Fatalf("HasSuicided returned false, expected true")
}
}

View File

@ -18,99 +18,216 @@ package light
import ( import (
"context" "context"
"fmt"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/trie" "github.com/ethereum/go-ethereum/trie"
) )
// LightTrie is an ODR-capable wrapper around trie.SecureTrie func NewState(ctx context.Context, head *types.Header, odr OdrBackend) *state.StateDB {
type LightTrie struct { state, _ := state.New(head.Root, NewStateDatabase(ctx, head, odr))
trie *trie.SecureTrie return state
id *TrieID
odr OdrBackend
db ethdb.Database
} }
// NewLightTrie creates a new LightTrie instance. It doesn't instantly try to func NewStateDatabase(ctx context.Context, head *types.Header, odr OdrBackend) state.Database {
// access the db or network and retrieve the root node, it only initializes its return &odrDatabase{ctx, StateTrieID(head), odr}
// encapsulated SecureTrie at the first actual operation. }
func NewLightTrie(id *TrieID, odr OdrBackend, useFakeMap bool) *LightTrie {
return &LightTrie{ type odrDatabase struct {
// SecureTrie is initialized before first request ctx context.Context
id: id, id *TrieID
odr: odr, backend OdrBackend
db: odr.Database(), }
func (db *odrDatabase) OpenTrie(root common.Hash) (state.Trie, error) {
return &odrTrie{db: db, id: db.id}, nil
}
func (db *odrDatabase) OpenStorageTrie(addrHash, root common.Hash) (state.Trie, error) {
return &odrTrie{db: db, id: StorageTrieID(db.id, addrHash, root)}, nil
}
func (db *odrDatabase) CopyTrie(t state.Trie) state.Trie {
switch t := t.(type) {
case *odrTrie:
cpy := &odrTrie{db: t.db, id: t.id}
if t.trie != nil {
cpytrie := *t.trie
cpy.trie = &cpytrie
}
return cpy
default:
panic(fmt.Errorf("unknown trie type %T", t))
} }
} }
// retrieveKey retrieves a single key, returns true and stores nodes in local func (db *odrDatabase) ContractCode(addrHash, codeHash common.Hash) ([]byte, error) {
// database if successful if codeHash == sha3_nil {
func (t *LightTrie) retrieveKey(ctx context.Context, key []byte) bool { return nil, nil
r := &TrieRequest{Id: t.id, Key: crypto.Keccak256(key)} }
return t.odr.Retrieve(ctx, r) == nil if code, err := db.backend.Database().Get(codeHash[:]); err == nil {
return code, nil
}
id := *db.id
id.AccKey = addrHash[:]
req := &CodeRequest{Id: &id, Hash: codeHash}
err := db.backend.Retrieve(db.ctx, req)
return req.Data, err
}
func (db *odrDatabase) ContractCodeSize(addrHash, codeHash common.Hash) (int, error) {
code, err := db.ContractCode(addrHash, codeHash)
return len(code), err
}
type odrTrie struct {
db *odrDatabase
id *TrieID
trie *trie.Trie
}
func (t *odrTrie) TryGet(key []byte) ([]byte, error) {
key = crypto.Keccak256(key)
var res []byte
err := t.do(key, func() (err error) {
res, err = t.trie.TryGet(key)
return err
})
return res, err
}
func (t *odrTrie) TryUpdate(key, value []byte) error {
key = crypto.Keccak256(key)
return t.do(key, func() error {
return t.trie.TryDelete(key)
})
}
func (t *odrTrie) TryDelete(key []byte) error {
key = crypto.Keccak256(key)
return t.do(key, func() error {
return t.trie.TryDelete(key)
})
}
func (t *odrTrie) CommitTo(db trie.DatabaseWriter) (common.Hash, error) {
if t.trie == nil {
return t.id.Root, nil
}
return t.trie.CommitTo(db)
}
func (t *odrTrie) Hash() common.Hash {
if t.trie == nil {
return t.id.Root
}
return t.trie.Hash()
}
func (t *odrTrie) NodeIterator(startkey []byte) trie.NodeIterator {
return newNodeIterator(t, startkey)
}
func (t *odrTrie) GetKey(sha []byte) []byte {
return nil
} }
// do tries and retries to execute a function until it returns with no error or // do tries and retries to execute a function until it returns with no error or
// an error type other than MissingNodeError // an error type other than MissingNodeError
func (t *LightTrie) do(ctx context.Context, key []byte, fn func() error) error { func (t *odrTrie) do(key []byte, fn func() error) error {
err := fn() for {
for err != nil { var err error
if t.trie == nil {
t.trie, err = trie.New(t.id.Root, t.db.backend.Database())
}
if err == nil {
err = fn()
}
if _, ok := err.(*trie.MissingNodeError); !ok { if _, ok := err.(*trie.MissingNodeError); !ok {
return err return err
} }
if !t.retrieveKey(ctx, key) { r := &TrieRequest{Id: t.id, Key: key}
break if err := t.db.backend.Retrieve(t.db.ctx, r); err != nil {
return fmt.Errorf("can't fetch trie key %x: %v", key, err)
} }
err = fn() }
}
type nodeIterator struct {
trie.NodeIterator
t *odrTrie
err error
}
func newNodeIterator(t *odrTrie, startkey []byte) trie.NodeIterator {
it := &nodeIterator{t: t}
// Open the actual non-ODR trie if that hasn't happened yet.
if t.trie == nil {
it.do(func() error {
t, err := trie.New(t.id.Root, t.db.backend.Database())
if err == nil {
it.t.trie = t
} }
return err return err
})
}
it.do(func() error {
it.NodeIterator = it.t.trie.NodeIterator(startkey)
return it.NodeIterator.Error()
})
return it
} }
// Get returns the value for key stored in the trie. func (it *nodeIterator) Next(descend bool) bool {
// The value bytes must not be modified by the caller. var ok bool
func (t *LightTrie) Get(ctx context.Context, key []byte) (res []byte, err error) { it.do(func() error {
err = t.do(ctx, key, func() (err error) { ok = it.NodeIterator.Next(descend)
if t.trie == nil { return it.NodeIterator.Error()
t.trie, err = trie.NewSecure(t.id.Root, t.db, 0)
}
if err == nil {
res, err = t.trie.TryGet(key)
}
return
}) })
return return ok
} }
// Update associates key with value in the trie. Subsequent calls to // do runs fn and attempts to fill in missing nodes by retrieving.
// Get will return value. If value has length zero, any existing value func (it *nodeIterator) do(fn func() error) {
// is deleted from the trie and calls to Get will return nil. var lasthash common.Hash
// for {
// The value bytes must not be modified by the caller while they are it.err = fn()
// stored in the trie. missing, ok := it.err.(*trie.MissingNodeError)
func (t *LightTrie) Update(ctx context.Context, key, value []byte) (err error) { if !ok {
err = t.do(ctx, key, func() (err error) {
if t.trie == nil {
t.trie, err = trie.NewSecure(t.id.Root, t.db, 0)
}
if err == nil {
err = t.trie.TryUpdate(key, value)
}
return return
}) }
if missing.NodeHash == lasthash {
it.err = fmt.Errorf("retrieve loop for trie node %x", missing.NodeHash)
return return
}
lasthash = missing.NodeHash
r := &TrieRequest{Id: it.t.id, Key: nibblesToKey(missing.Path)}
if it.err = it.t.db.backend.Retrieve(it.t.db.ctx, r); it.err != nil {
return
}
}
} }
// Delete removes any existing value for key from the trie. func (it *nodeIterator) Error() error {
func (t *LightTrie) Delete(ctx context.Context, key []byte) (err error) { if it.err != nil {
err = t.do(ctx, key, func() (err error) { return it.err
if t.trie == nil {
t.trie, err = trie.NewSecure(t.id.Root, t.db, 0)
} }
if err == nil { return it.NodeIterator.Error()
err = t.trie.TryDelete(key) }
}
return func nibblesToKey(nib []byte) []byte {
}) if len(nib) > 0 && nib[len(nib)-1] == 0x10 {
return nib = nib[:len(nib)-1] // drop terminator
}
if len(nib)&1 == 1 {
nib = append(nib, 0) // make even
}
key := make([]byte, len(nib)/2)
for bi, ni := 0, 0; ni < len(nib); bi, ni = bi+1, ni+2 {
key[bi] = nib[ni]<<4 | nib[ni+1]
}
return key
} }

83
light/trie_test.go Normal file
View File

@ -0,0 +1,83 @@
// Copyright 2017 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package light
import (
"bytes"
"context"
"fmt"
"testing"
"github.com/davecgh/go-spew/spew"
"github.com/ethereum/go-ethereum/consensus/ethash"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/trie"
)
func TestNodeIterator(t *testing.T) {
var (
fulldb, _ = ethdb.NewMemDatabase()
lightdb, _ = ethdb.NewMemDatabase()
gspec = core.Genesis{Alloc: core.GenesisAlloc{testBankAddress: {Balance: testBankFunds}}}
genesis = gspec.MustCommit(fulldb)
)
gspec.MustCommit(lightdb)
blockchain, _ := core.NewBlockChain(fulldb, params.TestChainConfig, ethash.NewFullFaker(), new(event.TypeMux), vm.Config{})
gchain, _ := core.GenerateChain(params.TestChainConfig, genesis, fulldb, 4, testChainGen)
if _, err := blockchain.InsertChain(gchain); err != nil {
panic(err)
}
ctx := context.Background()
odr := &testOdr{sdb: fulldb, ldb: lightdb}
head := blockchain.CurrentHeader()
lightTrie, _ := NewStateDatabase(ctx, head, odr).OpenTrie(head.Root)
fullTrie, _ := state.NewDatabase(fulldb).OpenTrie(head.Root)
if err := diffTries(fullTrie, lightTrie); err != nil {
t.Fatal(err)
}
}
func diffTries(t1, t2 state.Trie) error {
i1 := trie.NewIterator(t1.NodeIterator(nil))
i2 := trie.NewIterator(t2.NodeIterator(nil))
for i1.Next() && i2.Next() {
if !bytes.Equal(i1.Key, i2.Key) {
spew.Dump(i2)
return fmt.Errorf("tries have different keys %x, %x", i1.Key, i2.Key)
}
if !bytes.Equal(i2.Value, i2.Value) {
return fmt.Errorf("tries differ at key %x", i1.Key)
}
}
switch {
case i1.Err != nil:
return fmt.Errorf("full trie iterator error: %v", i1.Err)
case i2.Err != nil:
return fmt.Errorf("light trie iterator error: %v", i1.Err)
case i1.Next():
return fmt.Errorf("full trie iterator has more k/v pairs")
case i2.Next():
return fmt.Errorf("light trie iterator has more k/v pairs")
}
return nil
}

View File

@ -24,6 +24,7 @@ import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/event"
@ -100,17 +101,18 @@ func NewTxPool(config *params.ChainConfig, eventMux *event.TypeMux, chain *Light
} }
// currentState returns the light state of the current head header // currentState returns the light state of the current head header
func (pool *TxPool) currentState() *LightState { func (pool *TxPool) currentState(ctx context.Context) *state.StateDB {
return NewLightState(StateTrieID(pool.chain.CurrentHeader()), pool.odr) return NewState(ctx, pool.chain.CurrentHeader(), pool.odr)
} }
// GetNonce returns the "pending" nonce of a given address. It always queries // GetNonce returns the "pending" nonce of a given address. It always queries
// the nonce belonging to the latest header too in order to detect if another // the nonce belonging to the latest header too in order to detect if another
// client using the same key sent a transaction. // client using the same key sent a transaction.
func (pool *TxPool) GetNonce(ctx context.Context, addr common.Address) (uint64, error) { func (pool *TxPool) GetNonce(ctx context.Context, addr common.Address) (uint64, error) {
nonce, err := pool.currentState().GetNonce(ctx, addr) state := pool.currentState(ctx)
if err != nil { nonce := state.GetNonce(addr)
return 0, err if state.Error() != nil {
return 0, state.Error()
} }
sn, ok := pool.nonce[addr] sn, ok := pool.nonce[addr]
if ok && sn > nonce { if ok && sn > nonce {
@ -357,14 +359,10 @@ func (pool *TxPool) validateTx(ctx context.Context, tx *types.Transaction) error
return core.ErrInvalidSender return core.ErrInvalidSender
} }
// Last but not least check for nonce errors // Last but not least check for nonce errors
currentState := pool.currentState() currentState := pool.currentState(ctx)
if n, err := currentState.GetNonce(ctx, from); err == nil { if n := currentState.GetNonce(from); n > tx.Nonce() {
if n > tx.Nonce() {
return core.ErrNonceTooLow return core.ErrNonceTooLow
} }
} else {
return err
}
// Check the transaction doesn't exceed the current // Check the transaction doesn't exceed the current
// block limit gas. // block limit gas.
@ -382,20 +380,16 @@ func (pool *TxPool) validateTx(ctx context.Context, tx *types.Transaction) error
// Transactor should have enough funds to cover the costs // Transactor should have enough funds to cover the costs
// cost == V + GP * GL // cost == V + GP * GL
if b, err := currentState.GetBalance(ctx, from); err == nil { if b := currentState.GetBalance(from); b.Cmp(tx.Cost()) < 0 {
if b.Cmp(tx.Cost()) < 0 {
return core.ErrInsufficientFunds return core.ErrInsufficientFunds
} }
} else {
return err
}
// Should supply enough intrinsic gas // Should supply enough intrinsic gas
if tx.Gas().Cmp(core.IntrinsicGas(tx.Data(), tx.To() == nil, pool.homestead)) < 0 { if tx.Gas().Cmp(core.IntrinsicGas(tx.Data(), tx.To() == nil, pool.homestead)) < 0 {
return core.ErrIntrinsicGas return core.ErrIntrinsicGas
} }
return nil return currentState.Error()
} }
// add validates a new transaction and sets its state pending if processable. // add validates a new transaction and sets its state pending if processable.

View File

@ -1,194 +0,0 @@
// Copyright 2016 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package light
import (
"context"
"math/big"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
)
// VMState is a wrapper for the light state that holds the actual context and
// passes it to any state operation that requires it.
type VMState struct {
ctx context.Context
state *LightState
snapshots []*LightState
err error
}
func NewVMState(ctx context.Context, state *LightState) *VMState {
return &VMState{ctx: ctx, state: state}
}
func (s *VMState) Error() error {
return s.err
}
func (s *VMState) AddLog(log *types.Log) {}
func (s *VMState) AddPreimage(hash common.Hash, preimage []byte) {}
// errHandler handles and stores any state error that happens during execution.
func (s *VMState) errHandler(err error) {
if err != nil && s.err == nil {
s.err = err
}
}
func (self *VMState) Snapshot() int {
self.snapshots = append(self.snapshots, self.state.Copy())
return len(self.snapshots) - 1
}
func (self *VMState) RevertToSnapshot(idx int) {
self.state.Set(self.snapshots[idx])
self.snapshots = self.snapshots[:idx]
}
// CreateAccount creates creates a new account object and takes ownership.
func (s *VMState) CreateAccount(addr common.Address) {
_, err := s.state.CreateStateObject(s.ctx, addr)
s.errHandler(err)
}
// AddBalance adds the given amount to the balance of the specified account
func (s *VMState) AddBalance(addr common.Address, amount *big.Int) {
err := s.state.AddBalance(s.ctx, addr, amount)
s.errHandler(err)
}
// SubBalance adds the given amount to the balance of the specified account
func (s *VMState) SubBalance(addr common.Address, amount *big.Int) {
err := s.state.SubBalance(s.ctx, addr, amount)
s.errHandler(err)
}
// ForEachStorage calls a callback function for every key/value pair found
// in the local storage cache. Note that unlike core/state.StateObject,
// light.StateObject only returns cached values and doesn't download the
// entire storage tree.
func (s *VMState) ForEachStorage(addr common.Address, cb func(key, value common.Hash) bool) {
err := s.state.ForEachStorage(s.ctx, addr, cb)
s.errHandler(err)
}
// GetBalance retrieves the balance from the given address or 0 if the account does
// not exist
func (s *VMState) GetBalance(addr common.Address) *big.Int {
res, err := s.state.GetBalance(s.ctx, addr)
s.errHandler(err)
return res
}
// GetNonce returns the nonce at the given address or 0 if the account does
// not exist
func (s *VMState) GetNonce(addr common.Address) uint64 {
res, err := s.state.GetNonce(s.ctx, addr)
s.errHandler(err)
return res
}
// SetNonce sets the nonce of the specified account
func (s *VMState) SetNonce(addr common.Address, nonce uint64) {
err := s.state.SetNonce(s.ctx, addr, nonce)
s.errHandler(err)
}
// GetCode returns the contract code at the given address or nil if the account
// does not exist
func (s *VMState) GetCode(addr common.Address) []byte {
res, err := s.state.GetCode(s.ctx, addr)
s.errHandler(err)
return res
}
// GetCodeHash returns the contract code hash at the given address
func (s *VMState) GetCodeHash(addr common.Address) common.Hash {
res, err := s.state.GetCode(s.ctx, addr)
s.errHandler(err)
return crypto.Keccak256Hash(res)
}
// GetCodeSize returns the contract code size at the given address
func (s *VMState) GetCodeSize(addr common.Address) int {
res, err := s.state.GetCode(s.ctx, addr)
s.errHandler(err)
return len(res)
}
// SetCode sets the contract code at the specified account
func (s *VMState) SetCode(addr common.Address, code []byte) {
err := s.state.SetCode(s.ctx, addr, code)
s.errHandler(err)
}
// AddRefund adds an amount to the refund value collected during a vm execution
func (s *VMState) AddRefund(gas *big.Int) {
s.state.AddRefund(gas)
}
// GetRefund returns the refund value collected during a vm execution
func (s *VMState) GetRefund() *big.Int {
return s.state.GetRefund()
}
// GetState returns the contract storage value at storage address b from the
// contract address a or common.Hash{} if the account does not exist
func (s *VMState) GetState(a common.Address, b common.Hash) common.Hash {
res, err := s.state.GetState(s.ctx, a, b)
s.errHandler(err)
return res
}
// SetState sets the storage value at storage address key of the account addr
func (s *VMState) SetState(addr common.Address, key common.Hash, value common.Hash) {
err := s.state.SetState(s.ctx, addr, key, value)
s.errHandler(err)
}
// Suicide marks an account to be removed and clears its balance
func (s *VMState) Suicide(addr common.Address) bool {
res, err := s.state.Suicide(s.ctx, addr)
s.errHandler(err)
return res
}
// Exist returns true if an account exists at the given address
func (s *VMState) Exist(addr common.Address) bool {
res, err := s.state.HasAccount(s.ctx, addr)
s.errHandler(err)
return res
}
// Empty returns true if the account at the given address is considered empty
func (s *VMState) Empty(addr common.Address) bool {
so, err := s.state.GetStateObject(s.ctx, addr)
s.errHandler(err)
return so == nil || so.empty()
}
// HasSuicided returns true if the given account has been marked for deletion
// or false if the account does not exist
func (s *VMState) HasSuicided(addr common.Address) bool {
res, err := s.state.HasSuicided(s.ctx, addr)
s.errHandler(err)
return res
}

View File

@ -274,7 +274,7 @@ func (self *worker) wait() {
} }
go self.mux.Post(core.NewMinedBlockEvent{Block: block}) go self.mux.Post(core.NewMinedBlockEvent{Block: block})
} else { } else {
work.state.Commit(self.config.IsEIP158(block.Number())) work.state.CommitTo(self.chainDb, self.config.IsEIP158(block.Number()))
stat, err := self.chain.WriteBlock(block) stat, err := self.chain.WriteBlock(block)
if err != nil { if err != nil {
log.Error("Failed writing block to chain", "err", err) log.Error("Failed writing block to chain", "err", err)

View File

@ -204,7 +204,7 @@ func runBlockTest(homesteadBlock, daoForkBlock, gasPriceFork *big.Int, test *Blo
// InsertPreState populates the given database with the genesis // InsertPreState populates the given database with the genesis
// accounts defined by the test. // accounts defined by the test.
func (t *BlockTest) InsertPreState(db ethdb.Database) (*state.StateDB, error) { func (t *BlockTest) InsertPreState(db ethdb.Database) (*state.StateDB, error) {
statedb, err := state.New(common.Hash{}, db) statedb, err := state.New(common.Hash{}, state.NewDatabase(db))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -232,7 +232,7 @@ func (t *BlockTest) InsertPreState(db ethdb.Database) (*state.StateDB, error) {
} }
} }
root, err := statedb.Commit(false) root, err := statedb.CommitTo(db, false)
if err != nil { if err != nil {
return nil, fmt.Errorf("error writing state: %v", err) return nil, fmt.Errorf("error writing state: %v", err)
} }

View File

@ -20,7 +20,6 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
"math/big"
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
@ -99,7 +98,7 @@ func benchStateTest(chainConfig *params.ChainConfig, test VmTest, env map[string
statedb := makePreState(db, test.Pre) statedb := makePreState(db, test.Pre)
b.StartTimer() b.StartTimer()
RunState(chainConfig, statedb, env, test.Exec) RunState(chainConfig, statedb, db, env, test.Exec)
} }
func runStateTests(chainConfig *params.ChainConfig, tests map[string]VmTest, skipTests []string) error { func runStateTests(chainConfig *params.ChainConfig, tests map[string]VmTest, skipTests []string) error {
@ -143,16 +142,9 @@ func runStateTest(chainConfig *params.ChainConfig, test VmTest) error {
env["currentTimestamp"] = test.Env.CurrentTimestamp.(string) env["currentTimestamp"] = test.Env.CurrentTimestamp.(string)
} }
var ( ret, logs, root, _ := RunState(chainConfig, statedb, db, env, test.Transaction)
ret []byte
// gas *big.Int
// err error
logs []*types.Log
)
ret, logs, _, _ = RunState(chainConfig, statedb, env, test.Transaction) // Return value:
// Compare expected and actual return
var rexp []byte var rexp []byte
if strings.HasPrefix(test.Out, "#") { if strings.HasPrefix(test.Out, "#") {
n, _ := strconv.Atoi(test.Out[1:]) n, _ := strconv.Atoi(test.Out[1:])
@ -163,61 +155,43 @@ func runStateTest(chainConfig *params.ChainConfig, test VmTest) error {
if !bytes.Equal(rexp, ret) { if !bytes.Equal(rexp, ret) {
return fmt.Errorf("return failed. Expected %x, got %x\n", rexp, ret) return fmt.Errorf("return failed. Expected %x, got %x\n", rexp, ret)
} }
// Post state content:
// check post state
for addr, account := range test.Post { for addr, account := range test.Post {
address := common.HexToAddress(addr) address := common.HexToAddress(addr)
if !statedb.Exist(address) { if !statedb.Exist(address) {
return fmt.Errorf("did not find expected post-state account: %s", addr) return fmt.Errorf("did not find expected post-state account: %s", addr)
} }
if balance := statedb.GetBalance(address); balance.Cmp(math.MustParseBig256(account.Balance)) != 0 { if balance := statedb.GetBalance(address); balance.Cmp(math.MustParseBig256(account.Balance)) != 0 {
return fmt.Errorf("(%x) balance failed. Expected: %v have: %v\n", address[:4], math.MustParseBig256(account.Balance), balance) return fmt.Errorf("(%x) balance failed. Expected: %v have: %v\n", address[:4], math.MustParseBig256(account.Balance), balance)
} }
if nonce := statedb.GetNonce(address); nonce != math.MustParseUint64(account.Nonce) { if nonce := statedb.GetNonce(address); nonce != math.MustParseUint64(account.Nonce) {
return fmt.Errorf("(%x) nonce failed. Expected: %v have: %v\n", address[:4], account.Nonce, nonce) return fmt.Errorf("(%x) nonce failed. Expected: %v have: %v\n", address[:4], account.Nonce, nonce)
} }
for addr, value := range account.Storage { for addr, value := range account.Storage {
v := statedb.GetState(address, common.HexToHash(addr)) v := statedb.GetState(address, common.HexToHash(addr))
vexp := common.HexToHash(value) vexp := common.HexToHash(value)
if v != vexp { if v != vexp {
return fmt.Errorf("storage failed:\n%x: %s:\nexpected: %x\nhave: %x\n(%v %v)\n", address[:4], addr, vexp, v, vexp.Big(), v.Big()) return fmt.Errorf("storage failed:\n%x: %s:\nexpected: %x\nhave: %x\n(%v %v)\n", address[:4], addr, vexp, v, vexp.Big(), v.Big())
} }
} }
} }
// Root:
root, _ := statedb.Commit(false)
if common.HexToHash(test.PostStateRoot) != root { if common.HexToHash(test.PostStateRoot) != root {
return fmt.Errorf("Post state root error. Expected: %s have: %x", test.PostStateRoot, root) return fmt.Errorf("Post state root error. Expected: %s have: %x", test.PostStateRoot, root)
} }
// Logs:
// check logs return checkLogs(test.Logs, logs)
if len(test.Logs) > 0 {
if err := checkLogs(test.Logs, logs); err != nil {
return err
}
}
return nil
} }
func RunState(chainConfig *params.ChainConfig, statedb *state.StateDB, env, tx map[string]string) ([]byte, []*types.Log, *big.Int, error) { func RunState(chainConfig *params.ChainConfig, statedb *state.StateDB, db ethdb.Database, env, tx map[string]string) ([]byte, []*types.Log, common.Hash, error) {
environment, msg := NewEVMEnvironment(false, chainConfig, statedb, env, tx) environment, msg := NewEVMEnvironment(false, chainConfig, statedb, env, tx)
gaspool := new(core.GasPool).AddGas(math.MustParseBig256(env["currentGasLimit"])) gaspool := new(core.GasPool).AddGas(math.MustParseBig256(env["currentGasLimit"]))
root, _ := statedb.Commit(false)
statedb.Reset(root)
snapshot := statedb.Snapshot() snapshot := statedb.Snapshot()
ret, _, err := core.ApplyMessage(environment, msg, gaspool)
ret, gasUsed, err := core.ApplyMessage(environment, msg, gaspool)
if err != nil { if err != nil {
statedb.RevertToSnapshot(snapshot) statedb.RevertToSnapshot(snapshot)
} }
statedb.Commit(chainConfig.IsEIP158(environment.Context.BlockNumber)) root, _ := statedb.CommitTo(db, chainConfig.IsEIP158(environment.Context.BlockNumber))
return ret, statedb.Logs(), root, err
return ret, statedb.Logs(), gasUsed, err
} }

View File

@ -48,7 +48,6 @@ func init() {
} }
func checkLogs(tlog []Log, logs []*types.Log) error { func checkLogs(tlog []Log, logs []*types.Log) error {
if len(tlog) != len(logs) { if len(tlog) != len(logs) {
return fmt.Errorf("log length mismatch. Expected %d, got %d", len(tlog), len(logs)) return fmt.Errorf("log length mismatch. Expected %d, got %d", len(tlog), len(logs))
} else { } else {
@ -106,10 +105,14 @@ func (self Log) Topics() [][]byte {
} }
func makePreState(db ethdb.Database, accounts map[string]Account) *state.StateDB { func makePreState(db ethdb.Database, accounts map[string]Account) *state.StateDB {
statedb, _ := state.New(common.Hash{}, db) sdb := state.NewDatabase(db)
statedb, _ := state.New(common.Hash{}, sdb)
for addr, account := range accounts { for addr, account := range accounts {
insertAccount(statedb, addr, account) insertAccount(statedb, addr, account)
} }
// Commit and re-open to start with a clean state.
root, _ := statedb.CommitTo(db, false)
statedb, _ = state.New(root, sdb)
return statedb return statedb
} }

View File

@ -125,7 +125,7 @@ func VerifyProof(rootHash common.Hash, key []byte, proof []rlp.RawValue) (value
} }
func get(tn node, key []byte) ([]byte, node) { func get(tn node, key []byte) ([]byte, node) {
for len(key) > 0 { for {
switch n := tn.(type) { switch n := tn.(type) {
case *shortNode: case *shortNode:
if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) { if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) {
@ -140,9 +140,10 @@ func get(tn node, key []byte) ([]byte, node) {
return key, n return key, n
case nil: case nil:
return key, nil return key, nil
case valueNode:
return nil, n
default: default:
panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
} }
} }
return nil, tn.(valueNode)
} }

View File

@ -156,6 +156,11 @@ func (t *SecureTrie) Root() []byte {
return t.trie.Root() return t.trie.Root()
} }
func (t *SecureTrie) Copy() *SecureTrie {
cpy := *t
return &cpy
}
// NodeIterator returns an iterator that returns nodes of the underlying trie. Iteration // NodeIterator returns an iterator that returns nodes of the underlying trie. Iteration
// starts at the key after the given start key. // starts at the key after the given start key.
func (t *SecureTrie) NodeIterator(start []byte) NodeIterator { func (t *SecureTrie) NodeIterator(start []byte) NodeIterator {